diff --git a/.gitmodules b/.gitmodules index 0ebafd5bc..87cf0015e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -9,4 +9,4 @@ [submodule "parsers/onnx"] path = parsers/onnx url = https://github.com/onnx/onnx-tensorrt.git - branch = main + branch = release/10.6-GA diff --git a/CHANGELOG.md b/CHANGELOG.md index f1d4f86af..3bce1c8db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,37 @@ # TensorRT OSS Release Changelog +## 10.6.0 GA - 2024-11-05 +Key Feature and Updates: +- Demo Changes + - demoBERT: The use of `fcPlugin` in demoBERT has been removed. + - demoBERT: All TensorRT plugins now used in demoBERT (`CustomEmbLayerNormDynamic`, `CustomSkipLayerNormDynamic`, and `CustomQKVToContextDynamic`) now have versions that inherit from IPluginV3 interface classes. The user can opt-in to use these V3 plugins by specifying `--use-v3-plugins` to the builder scripts. + - Opting-in to use V3 plugins does not affect performance, I/O, or plugin attributes. + - There is a known issue in the V3 (version 4) of `CustomQKVToContextDynamic` plugin from TensorRT 10.6.0, causing an internal assertion error if either the batch or sequence dimensions differ at runtime from the ones used to serialize the engine. See the “known issues” section of the [TensorRT-10.6.0 release notes](https://docs.nvidia.com/deeplearning/tensorrt/release-notes/index.html#rel-10-6-0). + - For smoother migration, the default behavior is still using the deprecated `IPluginV2DynamicExt`-derived plugins, when the flag: `--use-v3-plugins` isn't specified in the builder scripts. The flag `--use-deprecated-plugins` was added as an explicit way to enforce the default behavior, and is mutually exclusive with `--use-v3-plugins`. + - demoDiffusion + - Introduced BF16 and FP8 support for the [Flux.1-dev](demo/Diffusion#generate-an-image-guided-by-a-text-prompt-using-flux) pipeline. + - Expanded FP8 support on Ada platforms. + - Enabled LoRA adapter compatibility for SDv1.5, SDv2.1, and SDXL pipelines using Diffusers version 0.30.3. + +- Sample Changes + - Added the Python sample [quickly_deployable_plugins](samples/python/quickly_deployable_plugins), which demonstrates quickly deployable Python-based plugin definitions (QDPs) in TensorRT. QDPs are a simple and intuitive decorator-based approach to defining TensorRT plugins, requiring drastically less code. + +- Plugin Changes + - The `fcPlugin` has been deprecated. Its functionality has been superseded by the [IMatrixMultiplyLayer](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_matrix_multiply_layer.html) that is natively provided by TensorRT. + - Migrated `IPluginV2`-descendent version 1 of `CustomEmbLayerNormDynamic`, to version 6, which implements `IPluginV3`. + - The newer versions preserve the attributes and I/O of the corresponding older plugin version. + - The older plugin versions are deprecated and will be removed in a future release. + +- Parser Changes + - Updated ONNX submodule version to 1.17.0. + - Fixed issue where conditional layers were incorrectly being added. + - Updated local function metadata to contain more information. + - Added support for parsing nodes with Quickly Deployable Plugins. + - Fixed handling of optional outputs. + +- Tool Updates + - ONNX-Graphsurgeon updated to version 0.5.3 + - Polygraphy updated to 0.49.14. ## 10.5.0 GA - 2024-09-30 Key Features and Updates: diff --git a/README.md b/README.md index bb2f2e413..247f86e20 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ You can skip the **Build** section to enjoy TensorRT with Python. To build the TensorRT-OSS components, you will first need the following software packages. **TensorRT GA build** -* TensorRT v10.5.0.18 +* TensorRT v10.6.0.26 * Available from direct download links listed below **System Packages** @@ -73,25 +73,25 @@ To build the TensorRT-OSS components, you will first need the following software If using the TensorRT OSS build container, TensorRT libraries are preinstalled under `/usr/lib/x86_64-linux-gnu` and you may skip this step. Else download and extract the TensorRT GA build from [NVIDIA Developer Zone](https://developer.nvidia.com) with the direct links below: - - [TensorRT 10.5.0.18 for CUDA 11.8, Linux x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-11.8.tar.gz) - - [TensorRT 10.5.0.18 for CUDA 12.6, Linux x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz) - - [TensorRT 10.5.0.18 for CUDA 11.8, Windows x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/zip/TensorRT-10.5.0.18.Windows.win10.cuda-11.8.zip) - - [TensorRT 10.5.0.18 for CUDA 12.6, Windows x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/zip/TensorRT-10.5.0.18.Windows.win10.cuda-12.6.zip) + - [TensorRT 10.6.0.26 for CUDA 11.8, Linux x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.6.0/tars/TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz) + - [TensorRT 10.6.0.26 for CUDA 12.6, Linux x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.6.0/tars/TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz) + - [TensorRT 10.6.0.26 for CUDA 11.8, Windows x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.6.0/zip/TensorRT-10.6.0.26.Windows.win10.cuda-11.8.zip) + - [TensorRT 10.6.0.26 for CUDA 12.6, Windows x86_64](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.6.0/zip/TensorRT-10.6.0.26.Windows.win10.cuda-12.6.zip) **Example: Ubuntu 20.04 on x86-64 with cuda-12.6** ```bash cd ~/Downloads - tar -xvzf TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz - export TRT_LIBPATH=`pwd`/TensorRT-10.5.0.18 + tar -xvzf TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz + export TRT_LIBPATH=`pwd`/TensorRT-10.6.0.26 ``` **Example: Windows on x86-64 with cuda-12.6** ```powershell - Expand-Archive -Path TensorRT-10.5.0.18.Windows.win10.cuda-12.6.zip - $env:TRT_LIBPATH="$pwd\TensorRT-10.5.0.18\lib" + Expand-Archive -Path TensorRT-10.6.0.26.Windows.win10.cuda-12.6.zip + $env:TRT_LIBPATH="$pwd\TensorRT-10.6.0.26\lib" ``` ## Setting Up The Build Environment diff --git a/VERSION b/VERSION index 5099d0963..eafccb088 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -10.5.0.18 +10.6.0.26 diff --git a/demo/BERT/README.md b/demo/BERT/README.md index 074848c31..5abc76faa 100755 --- a/demo/BERT/README.md +++ b/demo/BERT/README.md @@ -75,7 +75,7 @@ The following software version configuration has been tested: |Software|Version| |--------|-------| |Python|>=3.8| -|TensorRT|10.5.0.18| +|TensorRT|10.6.0.26| |CUDA|12.6| ## Setup @@ -122,7 +122,7 @@ This demo BERT application can be run within the TensorRT OSS build container. I bash scripts/download_model.sh ``` -**Note:** Since the datasets and checkpoints are stored in the directory mounted from the host, they do *not* need to be downloaded each time the container is launched. +**Note:** Since the datasets and checkpoints are stored in the directory mounted from the host, they do *not* need to be downloaded each time the container is launched. **Warning:** In the event of encountering an error message stating, "Missing API key and missing Email Authentication. This command requires an API key or authentication via browser login", the recommended steps for resolution are as follows: * Generate an API key by logging in https://ngc.nvidia.com/setup/api-key and copy the generated API key. @@ -153,11 +153,11 @@ Completing these steps should resolve the error you encountered and allow the co jupyter notebook --ip 0.0.0.0 inference.ipynb ``` Then, use your browser to open the link displayed. The link should look similar to: `http://127.0.0.1:8888/?token=` - + 6. Run inference with CUDA Graph support. A separate python `inference_c.py` script is provided to run inference with CUDA Graph support. This is necessary since CUDA Graph is only supported through CUDA C/C++ APIs, not pyCUDA. The `inference_c.py` script uses pybind11 to interface with C/C++ for CUDA graph capturing and launching. The cmdline interface is the same as `inference.py` except for an extra `--enable-graph` option. - + ```bash mkdir -p build; pushd build cmake .. -DPYTHON_EXECUTABLE=$(which python) @@ -167,11 +167,11 @@ Completing these steps should resolve the error you encountered and allow the co ``` A separate C/C++ inference benchmark executable `perf` (compiled from `perf.cpp`) is provided to run inference benchmarks with CUDA Graph. The cmdline interface is the same as `perf.py` except for an extra `--enable_graph` option. - + ```bash build/perf -e engines/bert_large_128.engine -b 1 -s 128 -w 100 -i 1000 --enable_graph ``` - + ### (Optional) Trying a different configuration @@ -220,6 +220,9 @@ The `infer_c/` folder contains all the necessary C/C++ files required for CUDA G To view the available parameters for each script, you can use the help flag (`-h`). +**Note:** In the builder scripts (`builder.py` and `builder_varseqlen.py`), the options `--use-deprecated-plugins` and `--use-v3-plugins` toggle the underlying implementation of the plugins used in demoBERT. They are mutually exclusive, and enabling either should not affect functionality, or performance. The `--use-deprecated-plugins` uses plugin versions that inherit from `IPluginV2DynamicExt`, while `--use-v3-plugins` uses plugin versions that inherit from `IPluginV3` classes. +If unspecified, `--use-deprecated-plugins` is used by default. + ### TensorRT inference process As mentioned in the [Quick Start Guide](#quick-start-guide), two options are provided for running inference: @@ -245,7 +248,7 @@ As mentioned in the [Quick Start Guide](#quick-start-guide), two options are pro **Xavier GPU** ```bash # Only supports SkipLayerNormPlugin running with INT8 I/O. Use -iln builder flag to enable. - mkdir -p engines && python3 builder.py -m models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/model.ckpt -o engines/bert_large_384_int8mix.engine -b 1 -s 384 --int8 --fp16 --strict -c models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1 --squad-json ./squad/train-v1.1.json -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt --calib-num 100 -iln + mkdir -p engines && python3 builder.py -m models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/model.ckpt -o engines/bert_large_384_int8mix.engine -b 1 -s 384 --int8 --fp16 --strict -c models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1 --squad-json ./squad/train-v1.1.json -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt --calib-num 100 -iln ``` **Volta GPU** @@ -278,13 +281,13 @@ As mentioned in the [Quick Start Guide](#quick-start-guide), two options are pro **Xavier GPU** ```bash # Only supports SkipLayerNormPlugin running with INT8 I/O. Use -iln builder flag to enable. - mkdir -p engines && python3 builder.py -o engines/bert_large_384_int8mix.engine -b 1 -s 384 --int8 --fp16 --strict -c models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1 -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt -x models/fine-tuned/bert_pyt_onnx_large_qa_squad11_amp_fake_quant_v1/bert_large_v1_1_fake_quant.onnx -iln + mkdir -p engines && python3 builder.py -o engines/bert_large_384_int8mix.engine -b 1 -s 384 --int8 --fp16 --strict -c models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1 -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt -x models/fine-tuned/bert_pyt_onnx_large_qa_squad11_amp_fake_quant_v1/bert_large_v1_1_fake_quant.onnx -iln ``` **Volta GPU** ```bash # No support for QKVToContextPlugin or SkipLayerNormPlugin running with INT8 I/O. Don't specify -imh or -iln in builder flags. - mkdir -p engines && python3 builder.py -o engines/bert_large_384_int8mix.engine -b 1 -s 384 --int8 --fp16 --strict -c models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1 -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt -x models/fine-tuned/bert_pyt_onnx_large_qa_squad11_amp_fake_quant_v1/bert_large_v1_1_fake_quant.onnx + mkdir -p engines && python3 builder.py -o engines/bert_large_384_int8mix.engine -b 1 -s 384 --int8 --fp16 --strict -c models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1 -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt -x models/fine-tuned/bert_pyt_onnx_large_qa_squad11_amp_fake_quant_v1/bert_large_v1_1_fake_quant.onnx ``` This will build and engine with a maximum batch size of 1 (`-b 1`) and sequence length of 384 (`-s 384`) using INT8 mixed precision computation where possible (`--int8 --fp16 --strict`). @@ -324,10 +327,10 @@ Note this is an experimental feature because we only support Xavier+ GPUs, also This will build and engine with a maximum batch size of 1 (`-b 1`) and sequence length of 256 (`-s 256`) using INT8 precision computation where possible (`--int8`). -3. Run inference +3. Run inference Evaluate the F1 score and exact match score using the squad dataset: - + ```bash python3 inference_varseqlen.py -e engines/bert_varseq_int8.engine -s 256 -sq ./squad/dev-v1.1.json -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt -o ./predictions.json python3 squad/evaluate-v1.1.py squad/dev-v1.1.json ./predictions.json 90 @@ -345,11 +348,11 @@ Note this is an experimental feature because we only support Xavier+ GPUs, also python3 perf_varseqlen.py -e engines/bert_varseq_int8.engine -b 1 -s 256 ``` - This will collect performance data run use batch size 1 (`-b 1`) and sequence length of 256 (`-s 256`). + This will collect performance data run use batch size 1 (`-b 1`) and sequence length of 256 (`-s 256`). 5. Collect performance data with CUDA graph enabled - We can use the same `inference_c.py` and `build/perf` to collect performance data with cuda graph enabled. The command line is the same as run without variable sequence length. + We can use the same `inference_c.py` and `build/perf` to collect performance data with cuda graph enabled. The command line is the same as run without variable sequence length. ### Sparsity with Quantization Aware Training diff --git a/demo/BERT/builder.py b/demo/BERT/builder.py index 90060ed2f..c2d2c0820 100755 --- a/demo/BERT/builder.py +++ b/demo/BERT/builder.py @@ -35,6 +35,10 @@ from builder_utils import WQKV, BQKV # Attention Keys from builder_utils import W_AOUT, B_AOUT, W_MID, B_MID, W_LOUT, B_LOUT # Transformer Keys from builder_utils import SQD_W, SQD_B # SQuAD Output Keys +from builder_utils import ( + create_plugin, + add_plugin_to_network, +) # Plugin Helper functions """ TensorRT Initialization @@ -51,13 +55,23 @@ trt.init_libnvinfer_plugins(TRT_LOGGER, "") plg_registry = trt.get_plugin_registry() -emln_plg_creator = plg_registry.get_plugin_creator("CustomEmbLayerNormPluginDynamic", "1", "") -qkv2_plg_creator = plg_registry.get_plugin_creator("CustomQKVToContextPluginDynamic", "1", "") -skln_plg_creator = plg_registry.get_plugin_creator("CustomSkipLayerNormPluginDynamic", "1", "") -fc_plg_creator = plg_registry.get_plugin_creator("CustomFCPluginDynamic", "1", "") + class BertConfig: - def __init__(self, bert_config_path, use_fp16, use_int8, use_strict, use_fc2_gemm, use_int8_skipln, use_int8_multihead, use_qat, use_sparsity, timing_cache): + def __init__( + self, + bert_config_path, + use_fp16, + use_int8, + use_strict, + use_fc2_gemm, + use_int8_skipln, + use_int8_multihead, + use_qat, + use_sparsity, + timing_cache, + use_deprecated_plugins=False, + ): with open(bert_config_path, "r") as f: data = json.load(f) self.num_attention_heads = data["num_attention_heads"] @@ -75,6 +89,8 @@ def __init__(self, bert_config_path, use_fp16, use_int8, use_strict, use_fc2_gem self.use_qat = use_qat self.use_sparsity = use_sparsity self.timing_cache = timing_cache + self.use_deprecated_plugins = use_deprecated_plugins + def set_tensor_name(tensor, prefix, name): tensor.name = prefix + name @@ -131,19 +147,26 @@ def attention_layer_opt(prefix, config, init_dict, network, input_tensor, imask) pf_dq_probs = trt.PluginField("dq_probs", np.array([dq_probs], np.float32), trt.PluginFieldType.FLOAT32) pfc = trt.PluginFieldCollection([pf_hidden_size, pf_num_heads, pf_has_mask, pf_type, pf_dq_probs]) else: - pfc = trt.PluginFieldCollection([pf_hidden_size, pf_num_heads, pf_has_mask, pf_type]) - qkv2ctx_plug = qkv2_plg_creator.create_plugin("qkv2ctx", pfc) + pfc = trt.PluginFieldCollection( + [pf_hidden_size, pf_num_heads, pf_has_mask, pf_type] + ) + qkv2ctx_plugin = create_plugin( + "qkv_to_context", plg_registry, pfc, use_deprecated_plugins=config.use_deprecated_plugins + ) qkv_in = [mult_all.get_output(0)] if has_mask: qkv_in.append(imask) - qkv2ctx = network.add_plugin_v2(qkv_in, qkv2ctx_plug) + + qkv2ctx_layer = add_plugin_to_network( + network, qkv2ctx_plugin, qkv_in, use_deprecated_plugins=config.use_deprecated_plugins + ) if config.use_qat: dr_ctx = init_dict[prefix + 'output_dense_input_amax'] - set_output_range(qkv2ctx, dr_ctx) - set_output_name(qkv2ctx, prefix, "context_layer") - return qkv2ctx + set_output_range(qkv2ctx_layer, dr_ctx) + set_output_name(qkv2ctx_layer, prefix, "context_layer") + return qkv2ctx_layer def skipln(prefix, config, init_dict, network, input_tensor, skip, bias=None): """ @@ -174,26 +197,15 @@ def skipln(prefix, config, init_dict, network, input_tensor, skip, bias=None): fields.append(pf_bias) pfc = trt.PluginFieldCollection(fields) - skipln_plug = skln_plg_creator.create_plugin("skipln", pfc) + skipln_plugin = create_plugin( + "skip_layer_norm", plg_registry, pfc, use_deprecated_plugins=config.use_deprecated_plugins + ) skipln_inputs = [input_tensor, skip] - layer = network.add_plugin_v2(skipln_inputs, skipln_plug) - return layer - -# Custom FC plugin is faster than native FC only on older architectures. -def use_custom_fc(): - cc = pycuda.autoinit.device.compute_capability() - return cc[0] * 10 + cc[1] <= 70 - -def custom_fc(config, network, input_tensor, out_dims, W): - pf_out_dims = trt.PluginField("out_dims", np.array([out_dims], dtype=np.int32), trt.PluginFieldType.INT32) - pf_W = trt.PluginField("W", W.numpy(), trt.PluginFieldType.FLOAT32) - pf_type = trt.PluginField("type_id", np.array([1 if config.use_fp16 else 0], np.int32), trt.PluginFieldType.INT32) - pfc = trt.PluginFieldCollection([pf_out_dims, pf_W, pf_type]) - fc_plugin = fc_plg_creator.create_plugin("fcplugin", pfc) - plug_inputs = [input_tensor] - out_dense = network.add_plugin_v2(plug_inputs, fc_plugin) - return out_dense + skipln_layer = add_plugin_to_network( + network, skipln_plugin, skipln_inputs, use_deprecated_plugins=config.use_deprecated_plugins + ) + return skipln_layer def transformer_layer_opt(prefix, config, init_dict, network, input_tensor, imask): """ @@ -214,22 +226,17 @@ def transformer_layer_opt(prefix, config, init_dict, network, input_tensor, imas # FC0 B_aout = init_dict[prefix + B_AOUT] - if not config.use_int8 and use_custom_fc(): - W_aoutT = init_dict[prefix + W_AOUT + "_notrans"] - attention_out_fc = custom_fc(config, network, attention_heads, hidden_size, W_aoutT) - else: - W_aout = init_dict[prefix + W_AOUT] - attention_out_fc = network.add_convolution_nd(attention_heads, hidden_size, (1, 1), W_aout, B_aout) - B_aout = None + W_aout = init_dict[prefix + W_AOUT] + attention_out_fc = network.add_convolution_nd(attention_heads, hidden_size, (1, 1), W_aout, B_aout) - if config.use_int8 and not config.use_int8_skipln: - attention_out_fc.set_output_type(0, trt.DataType.HALF if config.use_fp16 else trt.DataType.FLOAT) + if config.use_int8 and not config.use_int8_skipln: + attention_out_fc.set_output_type(0, trt.DataType.HALF if config.use_fp16 else trt.DataType.FLOAT) - if config.use_int8 and config.use_qat: - dr_fc_aout = init_dict[prefix + 'attention_output_add_local_input_quantizer_amax'] - set_output_range(attention_out_fc, dr_fc_aout) + if config.use_int8 and config.use_qat: + dr_fc_aout = init_dict[prefix + 'attention_output_add_local_input_quantizer_amax'] + set_output_range(attention_out_fc, dr_fc_aout) - skiplayer = skipln(prefix + "attention_output_layernorm_",config, init_dict, network, attention_out_fc.get_output(0), input_tensor, B_aout) + skiplayer = skipln(prefix + "attention_output_layernorm_",config, init_dict, network, attention_out_fc.get_output(0), input_tensor, bias=None) attention_ln = skiplayer.get_output(0) if config.use_qat: dr_skln1 = init_dict[prefix + 'intermediate_dense_input_amax'] @@ -271,24 +278,18 @@ def transformer_layer_opt(prefix, config, init_dict, network, input_tensor, imas # FC2 # Dense to hidden size B_lout = init_dict[prefix + B_LOUT] - prefer_conv = config.use_int8 and not config.use_fc2_gemm - if not prefer_conv and use_custom_fc(): - W_loutT = init_dict[prefix + W_LOUT + "_notrans"] - out_dense = custom_fc(config, network, intermediate_act, hidden_size, W_loutT) - else: - W_lout = init_dict[prefix + W_LOUT] - out_dense = network.add_convolution_nd(intermediate_act, hidden_size, (1, 1), W_lout, B_lout) - B_lout = None + W_lout = init_dict[prefix + W_LOUT] + out_dense = network.add_convolution_nd(intermediate_act, hidden_size, (1, 1), W_lout, B_lout) - if config.use_int8 and not config.use_int8_skipln: - out_dense.set_output_type(0, trt.DataType.HALF if config.use_fp16 else trt.DataType.FLOAT) + if config.use_int8 and not config.use_int8_skipln: + out_dense.set_output_type(0, trt.DataType.HALF if config.use_fp16 else trt.DataType.FLOAT) if config.use_qat: dr_fc_out = init_dict[prefix + 'output_add_local_input_quantizer_amax'] set_output_range(out_dense, dr_fc_out) set_output_name(out_dense, prefix + "output_", "dense") - out_layer = skipln(prefix + "output_layernorm_", config, init_dict, network, out_dense.get_output(0), attention_ln, B_lout) + out_layer = skipln(prefix + "output_layernorm_", config, init_dict, network, out_dense.get_output(0), attention_ln, bias=None) set_output_name(out_layer, prefix + "output_", "reshape") return out_layer @@ -366,8 +367,12 @@ def emb_layernorm(builder, network, config, weights_dict, builder_config, sequen output_fp16 = trt.PluginField("output_fp16", np.array([1 if config.use_fp16 else 0]).astype(np.int32), trt.PluginFieldType.INT32) mha_type = trt.PluginField("mha_type_id", np.array([get_mha_dtype(config)], np.int32), trt.PluginFieldType.INT32) - pfc = trt.PluginFieldCollection([wbeta, wgamma, wwordemb, wtokemb, wposemb, output_fp16, mha_type]) - fn = emln_plg_creator.create_plugin("embeddings", pfc) + pfc = trt.PluginFieldCollection( + [wbeta, wgamma, wwordemb, wtokemb, wposemb, output_fp16, mha_type] + ) + emln_plugin = create_plugin( + "emb_layer_norm", plg_registry, pfc, use_deprecated_plugins=config.use_deprecated_plugins + ) input_ids = network.add_shuffle(input_ids) input_ids.second_transpose = (1, 0) @@ -375,10 +380,14 @@ def emb_layernorm(builder, network, config, weights_dict, builder_config, sequen segment_ids.second_transpose = (1, 0) input_mask = network.add_shuffle(input_mask) input_mask.second_transpose = (1, 0) - inputs = [input_ids.get_output(0), - segment_ids.get_output(0), - input_mask.get_output(0)] - emb_layer = network.add_plugin_v2(inputs, fn) + inputs = [ + input_ids.get_output(0), + segment_ids.get_output(0), + input_mask.get_output(0), + ] + emb_layer = add_plugin_to_network( + network, emln_plugin, inputs, use_deprecated_plugins=config.use_deprecated_plugins + ) if config.use_qat: set_output_range(emb_layer, 1, 1) @@ -490,30 +499,151 @@ def generate_calibration_cache(sequence_lengths, workspace_size, config, weights config.is_calib_mode = False def main(): - parser = argparse.ArgumentParser(description="TensorRT BERT Sample", formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument("-m", "--ckpt", required=False, - help="The checkpoint file basename, e.g.: basename(model.ckpt-766908.data-00000-of-00001) is model.ckpt-766908") - parser.add_argument("-x", "--onnx", required=False, help="The ONNX model file path.") - parser.add_argument("-pt", "--pytorch", required=False, help="The PyTorch checkpoint file path.") - parser.add_argument("-o", "--output", required=True, default="bert_base_384.engine", help="The bert engine file, ex bert.engine") - parser.add_argument("-b", "--batch-size", default=[], action="append", help="Batch size(s) to optimize for. The engine will be usable with any batch size below this, but may not be optimal for smaller sizes. Can be specified multiple times to optimize for more than one batch size.", type=int) - parser.add_argument("-s", "--sequence-length", default=[], action="append", help="Sequence length of the BERT model", type=int) - parser.add_argument("-c", "--config-dir", required=True, - help="The folder containing the bert_config.json, which can be downloaded e.g. from https://github.com/google-research/bert#pre-trained-models or by running download_models.py in dle/TensorFlow/LanguageModeling/BERT/data/pretrained_models_google") - parser.add_argument("-f", "--fp16", action="store_true", help="Indicates that inference should be run in FP16 precision", required=False) - parser.add_argument("-i", "--int8", action="store_true", help="Indicates that inference should be run in INT8 precision", required=False) - parser.add_argument("-t", "--strict", action="store_true", help="Indicates that inference should be run in strict precision mode", required=False) - parser.add_argument("-w", "--workspace-size", default=2500, help="Workspace size in MiB for building the BERT engine", type=int) - parser.add_argument("-j", "--squad-json", default="squad/dev-v1.1.json", help="squad json dataset used for int8 calibration", required=False) - parser.add_argument("-v", "--vocab-file", default="./pre-trained_model/uncased_L-24_H-1024_A-16/vocab.txt", help="Path to file containing entire understandable vocab", required=False) - parser.add_argument("-n", "--calib-num", default=100, help="calibration batch numbers", type=int) - parser.add_argument("-p", "--calib-path", help="calibration cache path", required=False) - parser.add_argument("-g", "--force-fc2-gemm", action="store_true", help="Force use gemm to implement FC2 layer", required=False) - parser.add_argument("-iln", "--force-int8-skipln", action="store_true", help="Run skip layernorm with INT8 (FP32 or FP16 by default) inputs and output", required=False) - parser.add_argument("-imh", "--force-int8-multihead", action="store_true", help="Run multi-head attention with INT8 (FP32 or FP16 by default) input and output", required=False) - parser.add_argument("-sp", "--sparse", action="store_true", help="Indicates that model is sparse", required=False) - parser.add_argument("-tcf", "--timing-cache-file", help="Path to tensorrt build timeing cache file, only available for tensorrt 8.0 and later", required=False) - parser.add_argument("--verbose", action="store_true", help="Turn on verbose logger and set profiling verbosity to DETAILED", required=False) + parser = argparse.ArgumentParser( + description="TensorRT BERT Sample", + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "-m", + "--ckpt", + required=False, + help="The checkpoint file basename, e.g.: basename(model.ckpt-766908.data-00000-of-00001) is model.ckpt-766908 (default: None)", + ) + parser.add_argument( + "-x", "--onnx", required=False, help="The ONNX model file path. (default: None)" + ) + parser.add_argument( + "-pt", "--pytorch", required=False, help="The PyTorch checkpoint file path. (default: None)" + ) + parser.add_argument( + "-o", + "--output", + required=True, + default="bert_base_384.engine", + help="The bert engine file, ex bert.engine (default: bert_base_384.engine)", + ) + parser.add_argument( + "-b", + "--batch-size", + default=[], + action="append", + help="Batch size(s) to optimize for. The engine will be usable with any batch size below this, but may not be optimal for smaller sizes. Can be specified multiple times to optimize for more than one batch size. (default: [1])", + type=int, + ) + parser.add_argument( + "-s", + "--sequence-length", + default=[], + action="append", + help="Sequence length of the BERT model (default: [128])", + type=int, + ) + parser.add_argument( + "-c", + "--config-dir", + required=True, + help="The folder containing the bert_config.json, which can be downloaded e.g. from https://github.com/google-research/bert#pre-trained-models or by running download_models.py in dle/TensorFlow/LanguageModeling/BERT/data/pretrained_models_google", + ) + parser.add_argument( + "-f", + "--fp16", + action="store_true", + help="Indicates that inference should be run in FP16 precision (default: false)", + required=False, + ) + parser.add_argument( + "-i", + "--int8", + action="store_true", + help="Indicates that inference should be run in INT8 precision (default: false)", + required=False, + ) + parser.add_argument( + "-t", + "--strict", + action="store_true", + help="Indicates that inference should be run in strict precision mode (default: false)", + required=False, + ) + parser.add_argument( + "-w", + "--workspace-size", + default=2500, + help="Workspace size in MiB for building the BERT engine (default: 2500)", + type=int, + ) + parser.add_argument( + "-j", + "--squad-json", + default="squad/dev-v1.1.json", + help="squad json dataset used for int8 calibration (default: squad/dev-v1.1.json)", + required=False, + ) + parser.add_argument( + "-v", + "--vocab-file", + default="./pre-trained_model/uncased_L-24_H-1024_A-16/vocab.txt", + help="Path to file containing entire understandable vocab (default: ./pre-trained_model/uncased_L-24_H-1024_A-16/vocab.txt)", + required=False, + ) + parser.add_argument( + "-n", "--calib-num", default=100, help="calibration batch numbers (default: 100)", type=int + ) + parser.add_argument( + "-p", "--calib-path", help="calibration cache path (default: None)", required=False + ) + parser.add_argument( + "-g", + "--force-fc2-gemm", + action="store_true", + help="Force use gemm to implement FC2 layer (default: false)", + required=False, + ) + parser.add_argument( + "-iln", + "--force-int8-skipln", + action="store_true", + help="Run skip layernorm with INT8 (FP32 or FP16 by default) inputs and output (default: false)", + required=False, + ) + parser.add_argument( + "-imh", + "--force-int8-multihead", + action="store_true", + help="Run multi-head attention with INT8 (FP32 or FP16 by default) input and output (default: false)", + required=False, + ) + parser.add_argument( + "-sp", + "--sparse", + action="store_true", + help="Indicates that model is sparse (default: false)", + required=False, + ) + parser.add_argument( + "-tcf", + "--timing-cache-file", + help="Path to tensorrt build timeing cache file, only available for tensorrt 8.0 and later (default: None)", + required=False, + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Turn on verbose logger and set profiling verbosity to DETAILED (default: false)", + required=False, + ) + + plugin_group = parser.add_mutually_exclusive_group(required=False) + plugin_group.add_argument('--use-v3-plugins', + dest='use_deprecated_plugins', + action='store_false', + help="Use plugins implementing the IPluginV3 interface wherever TensorRT plugins are used. Cannot be used with --use-deprecated-plugins. Enabling this option should not affect functionality or performance. (default: false)") + plugin_group.add_argument('--use-deprecated-plugins', + dest='use_deprecated_plugins', + action='store_true', + help="Use deprecated plugins implementing the IPluginV2 interface wherever TensorRT plugins are used (instead of updated plugins implementing the IPluginV3 interface). Cannot be used with --use-v3-plugins. Disabling this option should not affect functionality or performance. (default: true)") + + parser.set_defaults(use_deprecated_plugins=True) args, _ = parser.parse_known_args() args.batch_size = args.batch_size or [1] @@ -531,7 +661,19 @@ def main(): bert_config_path = os.path.join(args.config_dir, "bert_config.json") TRT_LOGGER.log(TRT_LOGGER.INFO, "Using configuration file: {:}".format(bert_config_path)) - config = BertConfig(bert_config_path, args.fp16, args.int8, args.strict, args.force_fc2_gemm, args.force_int8_skipln, args.force_int8_multihead, args.int8 and args.onnx != None, args.sparse, args.timing_cache_file) + config = BertConfig( + bert_config_path, + args.fp16, + args.int8, + args.strict, + args.force_fc2_gemm, + args.force_int8_skipln, + args.force_int8_multihead, + args.int8 and args.onnx != None, + args.sparse, + args.timing_cache_file, + args.use_deprecated_plugins, + ) if args.calib_path != None: calib_cache = args.calib_path diff --git a/demo/BERT/builder_utils.py b/demo/BERT/builder_utils.py index abf0f5141..ed7f74487 100644 --- a/demo/BERT/builder_utils.py +++ b/demo/BERT/builder_utils.py @@ -313,3 +313,84 @@ def load_megatron_pickle_weights(path, config): TRT_LOGGER.log(TRT_LOGGER.INFO, "Found {:} entries in weight map".format(len(weight_dict))) return weight_dict + + +""" +Common Plugin Helper/Wrapper Functions +""" +BERT_PLUGINS_INFO_MAP = { + # MHA variants + "qkv_to_context": { + "IPluginV2_version": "1", + "IPluginV3_version": "4", + "trt_plugin_name": "CustomQKVToContextPluginDynamic", + }, + "qkv_to_context_varseqlen": { + "IPluginV2_version": "2", + "IPluginV3_version": "5", + "trt_plugin_name": "CustomQKVToContextPluginDynamic", + }, + "qkv_to_context_interleaved": { + "IPluginV2_version": "3", + "IPluginV3_version": "6", + "trt_plugin_name": "CustomQKVToContextPluginDynamic", + }, + # skipLayernorm variants + "skip_layer_norm": { + "IPluginV2_version": "1", + "IPluginV3_version": "5", + "trt_plugin_name": "CustomSkipLayerNormPluginDynamic", + }, + "skip_layer_norm_varseqlen": { + "IPluginV2_version": "2", + "IPluginV3_version": "6", + "trt_plugin_name": "CustomSkipLayerNormPluginDynamic", + }, + "skip_layer_norm_huggingface": { + "IPluginV2_version": "3", + "IPluginV3_version": "7", + "trt_plugin_name": "CustomSkipLayerNormPluginDynamic", + }, + "skip_layer_norm_megatron": { + "IPluginV2_version": "4", + "IPluginV3_version": "8", + "trt_plugin_name": "CustomSkipLayerNormPluginDynamic", + }, + # embLayernorm variants + "emb_layer_norm": { + "IPluginV2_version": "1", + "IPluginV3_version": "6", + "trt_plugin_name": "CustomEmbLayerNormPluginDynamic", + }, + "emb_layer_norm_huggingface": { + "IPluginV2_version": "2", + "IPluginV3_version": "4", + "trt_plugin_name": "CustomEmbLayerNormPluginDynamic", + }, + "emb_layer_norm_megatron": { + "IPluginV2_version": "3", + "IPluginV3_version": "5", + "trt_plugin_name": "CustomEmbLayerNormPluginDynamic", + }, +} + + +def create_plugin(layer_name, plg_registry, pfc, use_deprecated_plugins=False): + plg_trt_name = BERT_PLUGINS_INFO_MAP[layer_name]["trt_plugin_name"] + plg_version = BERT_PLUGINS_INFO_MAP[layer_name][ + ("IPluginV2_version" if use_deprecated_plugins else "IPluginV3_version") + ] + plg_namespace = "" + + creator = plg_registry.get_creator(plg_trt_name, plg_version, plg_namespace) + if use_deprecated_plugins: + return creator.create_plugin(layer_name, pfc) + else: + return creator.create_plugin(layer_name, pfc, trt.TensorRTPhase.BUILD) + + +def add_plugin_to_network(network, plugin, inputs, use_deprecated_plugins=False): + if use_deprecated_plugins: + return network.add_plugin_v2(inputs, plugin) + else: + return network.add_plugin_v3(inputs, [], plugin) diff --git a/demo/BERT/builder_varseqlen.py b/demo/BERT/builder_varseqlen.py index 33dd50607..b7328cd3e 100755 --- a/demo/BERT/builder_varseqlen.py +++ b/demo/BERT/builder_varseqlen.py @@ -34,6 +34,11 @@ from builder_utils import WQKV, BQKV # Attention Keys from builder_utils import W_AOUT, B_AOUT, W_MID, B_MID, W_LOUT, B_LOUT # Transformer Keys from builder_utils import SQD_W, SQD_B # SQuAD Output Keys +from builder_utils import ( + create_plugin, + add_plugin_to_network, +) # Plugin Helper functions + """ TensorRT Initialization @@ -50,19 +55,21 @@ trt.init_libnvinfer_plugins(TRT_LOGGER, "") plg_registry = trt.get_plugin_registry() -emln_plg_creator2 = plg_registry.get_plugin_creator("CustomEmbLayerNormPluginDynamic", "2", "") -mha_plg_creator2 = plg_registry.get_plugin_creator("CustomQKVToContextPluginDynamic", "2", "") -skln_plg_creator2 = plg_registry.get_plugin_creator("CustomSkipLayerNormPluginDynamic", "2", "") - -mha_plg_creator3 = plg_registry.get_plugin_creator("CustomQKVToContextPluginDynamic", "3", "") -skln_plg_creator3 = plg_registry.get_plugin_creator("CustomSkipLayerNormPluginDynamic", "3", "") -# Megatron Plugins -emln_plg_creator3 = plg_registry.get_plugin_creator("CustomEmbLayerNormPluginDynamic", "3", "") -skln_plg_creator4 = plg_registry.get_plugin_creator("CustomSkipLayerNormPluginDynamic", "4", "") class BertConfig: - def __init__(self, bert_config_path, use_fp16, use_int8, use_qat, interleaved, timing_cache, use_sparsity, use_megatron): + def __init__( + self, + bert_config_path, + use_fp16, + use_int8, + use_qat, + interleaved, + timing_cache, + use_sparsity, + use_megatron, + use_deprecated_plugins=False, + ): with open(bert_config_path, "r") as f: data = json.load(f) self.num_attention_heads = data["num_attention_heads"] @@ -77,6 +84,7 @@ def __init__(self, bert_config_path, use_fp16, use_int8, use_qat, interleaved, t self.timing_cache = timing_cache self.use_sparsity = use_sparsity self.use_megatron = use_megatron + self.use_deprecated_plugins = use_deprecated_plugins def get_trt_dtype(self): dtype = trt.float32 @@ -137,17 +145,30 @@ def attention_layer_opt(prefix, config, init_dict, network, input_tensor, mask_i if config.use_int8 and config.interleaved: pfc = trt.PluginFieldCollection(fields) - qkv2ctx_plug = mha_plg_creator3.create_plugin("qkv2ctx", pfc) + qkv2ctx_plug = create_plugin( + "qkv_to_context_interleaved", + plg_registry, + pfc, + use_deprecated_plugins=config.use_deprecated_plugins, + ) qkv_in = [mult_all.get_output(0), cu_seqlens, max_seqlen] else: fields.append(pf_has_mask) fields.append(pf_type) fields.append(pf_var_seqlen) pfc = trt.PluginFieldCollection(fields) - qkv2ctx_plug = mha_plg_creator2.create_plugin("qkv2ctx", pfc) + qkv2ctx_plug = create_plugin( + "qkv_to_context_varseqlen", + plg_registry, + pfc, + use_deprecated_plugins=config.use_deprecated_plugins, + ) qkv_in = [mult_all.get_output(0), mask_idx, cu_seqlens, max_seqlen] - qkv2ctx = network.add_plugin_v2(qkv_in, qkv2ctx_plug) - qkv2ctx.name = prefix + 'qkv_to_ctx' + qkv2ctx = add_plugin_to_network( + network, qkv2ctx_plug, qkv_in, use_deprecated_plugins=config.use_deprecated_plugins + ) + + qkv2ctx.name = prefix + "qkv_to_ctx" if config.use_qat: dr_ctx = init_dict[prefix + 'output_dense_input_amax'] @@ -171,14 +192,27 @@ def skipln(prefix, config, init_dict, network, input_tensor, skip, is_last_skipl if config.use_int8 and config.interleaved: pfc = trt.PluginFieldCollection([pf_beta, pf_gamma]) - creator = skln_plg_creator3 if not config.use_megatron or is_last_skipln else skln_plg_creator4 - skipln_plug = creator.create_plugin("skipln", pfc) + variant_name = ( + "skip_layer_norm_huggingface" + if not config.use_megatron or is_last_skipln + else "skip_layer_norm_megatron" + ) + skipln_plug = create_plugin( + variant_name, plg_registry, pfc, use_deprecated_plugins=config.use_deprecated_plugins + ) else: pfc = trt.PluginFieldCollection([pf_ld, pf_beta, pf_gamma, pf_type]) - skipln_plug = skln_plg_creator2.create_plugin("skipln", pfc) + skipln_plug = create_plugin( + "skip_layer_norm_varseqlen", + plg_registry, + pfc, + use_deprecated_plugins=config.use_deprecated_plugins, + ) skipln_inputs = [input_tensor, skip] - layer = network.add_plugin_v2(skipln_inputs, skipln_plug) + layer = add_plugin_to_network( + network, skipln_plug, skipln_inputs, use_deprecated_plugins=config.use_deprecated_plugins + ) return layer def transformer_layer_opt(prefix, config, init_dict, network, input_tensor, residual, mask_idx, cu_seqlens, max_seqlen): @@ -361,11 +395,23 @@ def emb_layernorm(builder, network, config, weights_dict, builder_config, max_se wposemb = trt.PluginField("bert_embeddings_position_embeddings", weights_dict["bert_embeddings_position_embeddings"].numpy(), trt.PluginFieldType.FLOAT32) output_fp16 = trt.PluginField("output_fp16", np.array([1 if config.use_fp16 or config.use_int8 else 0]).astype(np.int32), trt.PluginFieldType.INT32) - pfc = trt.PluginFieldCollection([wbeta, wgamma, wwordemb, wtokemb, wposemb, output_fp16]) - fn = (emln_plg_creator3 if config.use_megatron else emln_plg_creator2).create_plugin("embeddings", pfc) + pfc = trt.PluginFieldCollection( + [wbeta, wgamma, wwordemb, wtokemb, wposemb, output_fp16] + ) + variant_name = ( + "emb_layer_norm_megatron" + if config.use_megatron + else "emb_layer_norm_huggingface" + ) + fn = create_plugin( + variant_name, plg_registry, pfc, use_deprecated_plugins=config.use_deprecated_plugins + ) inputs = [input_ids, segment_ids, cu_seqlens, max_seqlen] - emb_layer = network.add_plugin_v2(inputs, fn) + + emb_layer = add_plugin_to_network( + network, fn, inputs, use_deprecated_plugins=config.use_deprecated_plugins + ) if config.use_int8 and config.use_qat: dr_input = weights_dict['l0_attention_self_query_input_amax'] @@ -458,29 +504,140 @@ def build_engine(batch_sizes, workspace_size, sequence_length, config, weights_d return serialized_engine def main(): - parser = argparse.ArgumentParser(description="TensorRT BERT Sample", formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument("-m", "--ckpt", required=False, - help="The checkpoint file basename, e.g.: basename(model.ckpt-766908.data-00000-of-00001) is model.ckpt-766908") - parser.add_argument("-x", "--onnx", required=False, help="The ONNX model file path.") - parser.add_argument("-pt", "--pytorch", required=False, help="The PyTorch checkpoint file path.") - parser.add_argument("-pkl", "--pickle", required=False, help="The Pickle weights dictionary file path for the Megatron variant of BERT.") - parser.add_argument("-o", "--output", required=True, default="bert_base_384.engine", help="The bert engine file, ex bert.engine") - parser.add_argument("-b", "--max-batch-size", default=[], action="append", help="Max batch size. The engine will be usable with any input with (batch-size * sequence-length) below (max-batch-size * max-sequence-length). Can be specified multiple times to build optimization profiles for more than one batch size.", type=int) - parser.add_argument("-s", "--max-sequence-length", default=128, help="Max sequence length of the BERT model. The engine will be usable with any input with (batch-size * sequence-length) below (max-batch-size * max-sequence-length).", type=int) - parser.add_argument("-c", "--config-dir", required=True, - help="The folder containing the bert_config.json, which can be downloaded e.g. from https://github.com/google-research/bert#pre-trained-models or by running download_models.py in dle/TensorFlow/LanguageModeling/BERT/data/pretrained_models_google") - parser.add_argument("-f", "--fp16", action="store_true", help="Indicates that inference should be run in FP16 precision", required=False) - parser.add_argument("-i", "--int8", action="store_true", help="Indicates that inference should be run in INT8 precision", required=False) - parser.add_argument("-w", "--workspace-size", default=2500, help="Workspace size in MiB for building the BERT engine", type=int) - parser.add_argument("-j", "--squad-json", default="squad/dev-v1.1.json", help="squad json dataset used for int8 calibration", required=False) - parser.add_argument("-v", "--vocab-file", default="./pre-trained_model/uncased_L-24_H-1024_A-16/vocab.txt", help="Path to file containing entire understandable vocab", required=False) - parser.add_argument("-n", "--calib-num", default=100, help="calibration batch numbers", type=int) - parser.add_argument("-p", "--calib-path", help="calibration cache path", required=False) - parser.add_argument("-il", "--interleaved", action="store_true", help="use interleaved format, only valid in INT8 precision", required=False) - parser.add_argument("-tcf", "--timing-cache-file", help="Path to tensorrt build timeing cache file, only available for tensorrt 8.0 and later", required=False) - parser.add_argument("-sp", "--sparse", action="store_true", help="Indicates that model is sparse", required=False) - parser.add_argument("--megatron", action="store_true", help="Indicates that model is the Megatron-style architecture", required=False) - parser.add_argument("--verbose", action="store_true", help="Turn on verbose logger and set profiling verbosity to verbose", required=False) + parser = argparse.ArgumentParser( + description="TensorRT BERT Sample", + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "-m", + "--ckpt", + required=False, + help="The checkpoint file basename, e.g.: basename(model.ckpt-766908.data-00000-of-00001) is model.ckpt-766908 (default: None)", + ) + parser.add_argument( + "-x", "--onnx", required=False, help="The ONNX model file path. (default: None)" + ) + parser.add_argument( + "-pt", "--pytorch", required=False, help="The PyTorch checkpoint file path. (default: None)" + ) + parser.add_argument( + "-pkl", + "--pickle", + required=False, + help="The Pickle weights dictionary file path for the Megatron variant of BERT. (default: None)", + ) + parser.add_argument( + "-o", + "--output", + required=True, + default="bert_base_384.engine", + help="The bert engine file, ex bert.engine (default: bert_base_384.engine)", + ) + parser.add_argument( + "-b", + "--max-batch-size", + default=[], + action="append", + help="Max batch size. The engine will be usable with any input with (batch-size * sequence-length) below (max-batch-size * max-sequence-length). Can be specified multiple times to build optimization profiles for more than one batch size. (default: [1])", + type=int, + ) + parser.add_argument( + "-s", + "--max-sequence-length", + default=128, + help="Max sequence length of the BERT model. The engine will be usable with any input with (batch-size * sequence-length) below (max-batch-size * max-sequence-length). (default: 128)", + type=int, + ) + parser.add_argument( + "-c", + "--config-dir", + required=True, + help="The folder containing the bert_config.json, which can be downloaded e.g. from https://github.com/google-research/bert#pre-trained-models or by running download_models.py in dle/TensorFlow/LanguageModeling/BERT/data/pretrained_models_google", + ) + parser.add_argument( + "-f", + "--fp16", + action="store_true", + help="Indicates that inference should be run in FP16 precision (default: false)", + required=False, + ) + parser.add_argument( + "-i", + "--int8", + action="store_true", + help="Indicates that inference should be run in INT8 precision (default: false)", + required=False, + ) + parser.add_argument( + "-w", + "--workspace-size", + default=2500, + help="Workspace size in MiB for building the BERT engine (default: 2500)", + type=int, + ) + parser.add_argument( + "-j", + "--squad-json", + default="squad/dev-v1.1.json", + help="squad json dataset used for int8 calibration (default: squad/dev-v1.1.json)", + required=False, + ) + parser.add_argument( + "-v", + "--vocab-file", + default="./pre-trained_model/uncased_L-24_H-1024_A-16/vocab.txt", + help="Path to file containing entire understandable vocab (default: ./pre-trained_model/uncased_L-24_H-1024_A-16/vocab.txt)", + required=False, + ) + parser.add_argument( + "-n", "--calib-num", default=100, help="calibration batch numbers (default: 100)", type=int + ) + parser.add_argument( + "-p", "--calib-path", help="calibration cache path (default: None)", required=False + ) + parser.add_argument( + "-il", + "--interleaved", + action="store_true", + help="use interleaved format, only valid in INT8 precision (default: false)", + required=False, + ) + parser.add_argument( + "-tcf", + "--timing-cache-file", + help="Path to tensorrt build timeing cache file, only available for tensorrt 8.0 and later (default: None)", + required=False, + ) + parser.add_argument( + "-sp", + "--sparse", + action="store_true", + help="Indicates that model is sparse (default: false)", + required=False, + ) + parser.add_argument( + "--megatron", + action="store_true", + help="Indicates that model is the Megatron-style architecture (default: false)", + required=False, + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Turn on verbose logger and set profiling verbosity to verbose (default: false)", + required=False, + ) + + plugin_group = parser.add_mutually_exclusive_group(required=False) + plugin_group.add_argument('--use-v3-plugins', + dest='use_deprecated_plugins', + action='store_false', + help="Use plugins implementing the IPluginV3 interface wherever TensorRT plugins are used. Cannot be used with --use-deprecated-plugins. Enabling this option should not affect functionality or performance. (default: false)") + plugin_group.add_argument('--use-deprecated-plugins', + dest='use_deprecated_plugins', + action='store_true', + help="Use deprecated plugins implementing the IPluginV2 interface wherever TensorRT plugins are used (instead of updated plugins implementing the IPluginV3 interface). Cannot be used with --use-v3-plugins. Disabling this option should not affect functionality or performance. (default: true)") + parser.set_defaults(use_deprecated_plugins=True) args, _ = parser.parse_known_args() args.max_batch_size = args.max_batch_size or [1] @@ -501,7 +658,17 @@ def main(): bert_config_path = os.path.join(args.config_dir, "bert_config.json") TRT_LOGGER.log(TRT_LOGGER.INFO, "Using configuration file: {:}".format(bert_config_path)) - config = BertConfig(bert_config_path, args.fp16, args.int8, args.int8 and (args.onnx or args.pytorch or args.pickle), args.interleaved, args.timing_cache_file, args.sparse, args.megatron) + config = BertConfig( + bert_config_path, + args.fp16, + args.int8, + args.int8 and (args.onnx or args.pytorch or args.pickle), + args.interleaved, + args.timing_cache_file, + args.sparse, + args.megatron, + args.use_deprecated_plugins, + ) if args.calib_path != None: calib_cache = args.calib_path diff --git a/demo/Diffusion/README.md b/demo/Diffusion/README.md index 778767a55..974bad5b3 100755 --- a/demo/Diffusion/README.md +++ b/demo/Diffusion/README.md @@ -48,7 +48,7 @@ onnx 1.15.0 onnx-graphsurgeon 0.5.2 onnxruntime 1.16.3 polygraphy 0.49.9 -tensorrt 10.5.0.18 +tensorrt 10.6.0.26 tokenizers 0.13.3 torch 2.2.0 transformers 4.42.2 @@ -148,7 +148,7 @@ python3 demo_txt2img_xl.py "a photo of an astronaut riding a horse on mars" --hf ### Generate an image guided by a text prompt, and using specified LoRA model weight updates ```bash -python3 demo_txt2img_xl.py "Picture of a rustic Italian village with Olive trees and mountains" --version=xl-1.0 --lora-path "ostris/crayon_style_lora_sdxl" "ostris/watercolor_style_lora_sdxl" --lora-scale 0.3 0.7 --onnx-dir onnx-sdxl-lora --engine-dir engine-sdxl-lora --build-enable-refit +python3 demo_txt2img_xl.py "Picture of a rustic Italian village with Olive trees and mountains" --version=xl-1.0 --lora-path "ostris/crayon_style_lora_sdxl" "ostris/watercolor_style_lora_sdxl" --lora-weight 0.3 0.7 --onnx-dir onnx-sdxl-lora --engine-dir engine-sdxl-lora --build-enable-refit ``` ### Faster Text-to-image using SDXL INT8 & FP8 quantization using ModelOpt @@ -174,7 +174,7 @@ For step-by-step tutorials to run INT8 & FP8 inference on stable diffusion model [LCM-LoRA](https://arxiv.org/abs/2311.05556) produces good quality images in 4 to 8 denoising steps instead of 30+ needed base model. Note that we use LCM scheduler and disable classifier-free-guidance by setting `--guidance-scale` to 0. LoRA weights are fused into the ONNX and finalized TensorRT plan files in this example. ```bash -python3 demo_txt2img_xl.py "Einstein" --version xl-1.0 --lora-path "latent-consistency/lcm-lora-sdxl" --lora-scale 1.0 --onnx-dir onnx-sdxl-lcm-nocfg --engine-dir engine-sdxl-lcm-nocfg --denoising-steps 4 --scheduler LCM --guidance-scale 0.0 +python3 demo_txt2img_xl.py "Einstein" --version xl-1.0 --lora-path "latent-consistency/lcm-lora-sdxl" --lora-weight 1.0 --onnx-dir onnx-sdxl-lcm-nocfg --engine-dir engine-sdxl-lcm-nocfg --denoising-steps 4 --scheduler LCM --guidance-scale 0.0 ``` ### Faster Text-to-Image using SDXL Turbo Even faster image generation than LCM, producing coherent images in just 1 step. Note: SDXL Turbo works best for 512x512 resolution, EulerA scheduler and classifier-free-guidance disabled. @@ -215,6 +215,21 @@ SVD-XT-1.1 (25 frames at resolution 576x1024) python3 demo_img2vid.py --version svd-xt-1.1 --onnx-dir onnx-svd-xt-1-1 --engine-dir engine-svd-xt-1-1 --hf-token=$HF_TOKEN ``` + +Run the command below to generate a video in FP8. + +```bash +python3 demo_img2vid.py --version svd-xt-1.1 --onnx-dir onnx-svd-xt-1-1 --engine-dir engine-svd-xt-1-1 --hf-token=$HF_TOKEN --fp8 +``` + +> NOTE: There is a bug in HuggingFace, you can workaround with following this [PR](https://github.com/huggingface/diffusers/pull/6562/files) + +``` +if torch.is_tensor(num_frames): + num_frames = num_frames.item() +emb = emb.repeat_interleave(num_frames, dim=0) +``` + You may also specify a custom conditioning image using `--input-image`: ```bash python3 demo_img2vid.py --version svd-xt-1.1 --onnx-dir onnx-svd-xt-1-1 --engine-dir engine-svd-xt-1-1 --input-image https://www.hdcarwallpapers.com/walls/2018_chevrolet_camaro_zl1_nascar_race_car_2-HD.jpg --hf-token=$HF_TOKEN @@ -246,6 +261,18 @@ python3 demo_stable_cascade.py --onnx-opset=16 "Anthropomorphic cat dressed as a python3 demo_txt2img_flux.py "a beautiful photograph of Mt. Fuji during cherry blossom" --hf-token=$HF_TOKEN ``` +Run the below command to generate an image with FLUX in BF16. + +```bash +python3 demo_txt2img_flux.py "a beautiful photograph of Mt. Fuji during cherry blossom" --hf-token=$HF_TOKEN --bf16 +``` + +Run the below command to generate an image with FLUX in FP8. (FP8 is only supppoted on Hopper.) + +```bash +python3 demo_txt2img_flux.py "a beautiful photograph of Mt. Fuji during cherry blossom" --hf-token=$HF_TOKEN --fp8 +``` + NOTE: Running the Flux pipeline requires 80GB of GPU memory or higher ## Configuration options @@ -254,8 +281,4 @@ NOTE: Running the Flux pipeline requires 80GB of GPU memory or higher - Specify new directories for storing onnx and engine files when switching between versions, LoRAs, ControlNets, etc. This can be done using `--onnx-dir ` and `--engine-dir `. - Inference performance can be improved by enabling [CUDA graphs](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-graphs) using `--use-cuda-graph`. Enabling CUDA graphs requires fixed input shapes, so this flag must be combined with `--build-static-batch` and cannot be combined with `--build-dynamic-shape`. -## Known Issues -- LoRA adapter functionality is compatible with diffusers version 0.26.3. To run the LoRA pipeline, we recommend installing this specific version. However, the Stable Cascade pipeline requires diffusers version 0.29.2 or higher and will not be compatible if diffusers is downgraded. - - diff --git a/demo/Diffusion/calibration-images/rocket.png b/demo/Diffusion/calibration-images/rocket.png new file mode 100644 index 000000000..3f6fef6fa Binary files /dev/null and b/demo/Diffusion/calibration-images/rocket.png differ diff --git a/demo/Diffusion/demo_img2vid.py b/demo/Diffusion/demo_img2vid.py index 7d7b2b048..197d39050 100644 --- a/demo/Diffusion/demo_img2vid.py +++ b/demo/Diffusion/demo_img2vid.py @@ -63,6 +63,19 @@ def process_pipeline_args(args): if not args.build_static_batch or args.build_dynamic_shape: raise ValueError(f"Dynamic shapes not supported. Do not specify `--build-dynamic-shape`") + if args.fp8: + import torch + device_info = torch.cuda.get_device_properties(0) + version = device_info.major * 10 + device_info.minor + if version < 90: # FP8 is only supppoted on Hopper. + raise ValueError(f"Cannot apply FP8 quantization for GPU with compute capability {version / 10.0}. FP8 is only supppoted on Hopper.") + args.optimization_level = 4 + print(f"[I] The default optimization level has been set to {args.optimization_level} for FP8.") + + if args.quantization_level == 0.0 and args.fp8: + args.quantization_level = 3.0 + print("[I] The default quantization level has been set to 3.0 for FP8.") + kwargs_init_pipeline = { 'version': args.version, 'max_batch_size': max_batch_size, @@ -89,6 +102,9 @@ def process_pipeline_args(args): 'enable_all_tactics': args.build_all_tactics, 'enable_refit': args.build_enable_refit, 'timing_cache': args.timing_cache, + 'fp8': args.fp8, + 'quantization_level': args.quantization_level, + } args_run_demo = (input_image, args.height, args.width, args.batch_size, args.batch_count, args.num_warmup_runs, args.use_cuda_graph) @@ -99,7 +115,6 @@ def process_pipeline_args(args): print("[I] Initializing StableDiffusion img2vid demo using TensorRT") args = parseArgs() kwargs_init_pipeline, kwargs_load_engine, args_run_demo = process_pipeline_args(args) - # Initialize demo demo = StableVideoDiffusionPipeline( pipeline_type=PIPELINE_TYPE.IMG2VID, diff --git a/demo/Diffusion/demo_txt2img_flux.py b/demo/Diffusion/demo_txt2img_flux.py index 067c2c8fe..501a6a05e 100644 --- a/demo/Diffusion/demo_txt2img_flux.py +++ b/demo/Diffusion/demo_txt2img_flux.py @@ -68,6 +68,17 @@ def parse_args(): default=512, help="Maximum sequence length to use with the prompt", ) + parser.add_argument( + "--bf16", + action='store_true', + help="Run pipeline in BFloat16 precision" + ) + parser.add_argument( + "--low-vram", + action='store_true', + help="Optimize for low VRAM usage, possibly at the expense of inference performance. Disabled by default." + ) + return parser.parse_args() @@ -118,8 +129,9 @@ def process_demo_args(args): demo = FluxPipeline( pipeline_type=PIPELINE_TYPE.TXT2IMG, max_sequence_length=args.max_sequence_length, - **kwargs_init_pipeline, - ) + bf16=args.bf16, + low_vram=args.low_vram, + **kwargs_init_pipeline) # Load TensorRT engines and pytorch modules demo.load_engines( diff --git a/demo/Diffusion/diffusion_pipeline.py b/demo/Diffusion/diffusion_pipeline.py index 02c7a6312..373323751 100644 --- a/demo/Diffusion/diffusion_pipeline.py +++ b/demo/Diffusion/diffusion_pipeline.py @@ -51,14 +51,18 @@ from typing import Optional, List from utils_modelopt import ( filter_func, + filter_func_no_proj_out, quantize_lvl, get_int8_config, check_lora, set_fmha, + set_quant_precision, generate_fp8_scales, + SD_FP8_BF16_DEFAULT_CONFIG, SD_FP8_FP16_DEFAULT_CONFIG, SD_FP8_FP32_DEFAULT_CONFIG, ) +import gc class DiffusionPipeline(ABC): """ @@ -101,7 +105,8 @@ def __init__( max_batch_size=16, denoising_steps=30, scheduler=None, - lora_scale: Optional[List[int]] = None, + lora_scale: float = 1.0, + lora_weight: Optional[List[float]] = None, lora_path: Optional[List[str]] = None, device='cuda', output_dir='.', @@ -129,7 +134,9 @@ def __init__( scheduler (str): The scheduler to guide the denoising process. Must be one of the values listed in DiffusionPipeline.SCHEDULER_DEFAULTS.values(). lora_scale (float): - Scale of LoRA weights, default 1 (must between 0 and 1). + Controls how much to influence the outputs with the LoRA parameters. (must between 0 and 1). + lora_weight (float): + The LoRA adapter(s) weights to use with the UNet. (must between 0 and 1). lora_path (str): Path to LoRA adaptor. Ex: 'latent-consistency/lcm-lora-sdv1-5'. device (str): @@ -204,16 +211,17 @@ def __init__( self.models = {} self.torch_models = {} self.engine = {} + self.shape_dicts = {} self.shared_device_memory = None # initialize lora loader and scales self.lora_loader = None - self.lora_scales = dict() + self.lora_weights = dict() if lora_path: - self.lora_loader = LoraLoader(lora_path) - assert len(lora_path) == len(lora_scale) + self.lora_loader = LoraLoader(lora_path, lora_weight, lora_scale) + assert len(lora_path) == len(lora_weight) for i, path in enumerate(lora_path): - self.lora_scales[path] = lora_scale[i] + self.lora_weights[path] = lora_weight[i] # initialized in load_resources() self.events = {} @@ -254,7 +262,9 @@ def load_resources(self, image_height, image_width, batch_size, seed): for model_name, obj in self.models.items(): if self.torch_fallback[model_name]: continue - self.engine[model_name].allocate_buffers(shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.device) + self.shape_dicts[model_name] = obj.get_shape_dict(batch_size, image_height, image_width) + if not self.low_vram: + self.engine[model_name].allocate_buffers(shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.device) def _create_directories(self, engine_dir, onnx_dir): # Create directories if missing @@ -297,7 +307,7 @@ def _initialize_models(self): def _get_lora_suffix(self): if self.lora_loader: - return '-' + '-'.join([str(md5(path.encode('utf-8')).hexdigest()) + '-' + ('%.2f' % self.lora_scales[path]) for path in sorted(self.lora_loader.paths)]) + return '-' + '-'.join([str(md5(path.encode('utf-8')).hexdigest()) + '-' + ('%.2f' % self.lora_weights[path]) + '-' + ('%.2f' % self.lora_loader.scale) for path in sorted(self.lora_loader.paths)]) return '' def _prepare_model_configs(self, onnx_dir, engine_dir, enable_refit, int8, fp8, quantization_level, quantization_percentile, quantization_alpha, calibration_size): @@ -317,12 +327,15 @@ def _prepare_model_configs(self, onnx_dir, engine_dir, enable_refit, int8, fp8, if int8: assert self.pipeline_type.is_sd_xl_base() or self.version in ["1.5", "2.1", "2.1-base"], "int8 quantization only supported for SDXL, SD1.5 and SD2.1 pipeline" - if model_name == ('unetxl' if self.pipeline_type.is_sd_xl() else 'unet'): + if (self.pipeline_type.is_sd_xl() and model_name == 'unetxl') or \ + (model_name == 'unet'): config['use_int8'] = True config['model_suffix'] += f"-int8.l{quantization_level}.bs2.s{self.denoising_steps}.c{calibration_size}.p{quantization_percentile}.a{quantization_alpha}" elif fp8: - assert self.pipeline_type.is_sd_xl() or self.version in ["1.5", "2.1", "2.1-base"], "fp8 quantization only supported for SDXL, SD1.5 and SD2.1 pipeline" - if model_name == ('unetxl' if self.pipeline_type.is_sd_xl() else 'unet'): + assert self.pipeline_type.is_sd_xl() or self.version in ["1.5", "2.1", "2.1-base", "flux.1-dev"], "fp8 quantization only supported for SDXL, SD1.5, SD2.1 and FLUX pipeline" + if (self.pipeline_type.is_sd_xl() and model_name == 'unetxl') or \ + (self.version == "flux.1-dev" and model_name == 'transformer') or \ + (model_name == 'unet'): config['use_fp8'] = True config['model_suffix'] += f"-fp8.l{quantization_level}.bs2.s{self.denoising_steps}.c{calibration_size}.p{quantization_percentile}.a{quantization_alpha}" @@ -337,7 +350,7 @@ def _prepare_model_configs(self, onnx_dir, engine_dir, enable_refit, int8, fp8, return configs - def _calibrate_and_save_model(self, pipeline, model, model_config, quantization_level, quantization_percentile, quantization_alpha, calibration_size, calib_batch_size): + def _calibrate_and_save_model(self, pipeline, model, model_config, quantization_level, quantization_percentile, quantization_alpha, calibration_size, calib_batch_size, **kwargs): print(f"[I] Calibrated weights not found, generating {model_config['state_dict_path']}") calibration_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'calibration-prompts.txt') calibration_prompts = load_calib_prompts(calib_batch_size, calibration_file) @@ -347,22 +360,42 @@ def do_calibrate(pipeline, calibration_prompts, **kwargs): for i_th, prompts in enumerate(calibration_prompts): if i_th >= kwargs["calib_size"]: return - pipeline( - prompt=prompts, - num_inference_steps=kwargs["n_steps"], - negative_prompt=[ - "normal quality, low quality, worst quality, low res, blurry, nsfw, nude" - ] - * len(prompts), - ).images + if kwargs["model_id"] == "flux.1-dev": + + height = kwargs.get("height", 1024) + width = kwargs.get("width", 1024) + pipeline( + prompt=prompts, + prompt_2=prompts, + num_inference_steps=kwargs["n_steps"], + height=height, + width=width, + guidance_scale=3.5, + max_sequence_length=512 + ).images + else: + pipeline( + prompt=prompts, + num_inference_steps=kwargs["n_steps"], + negative_prompt=[ + "normal quality, low quality, worst quality, low res, blurry, nsfw, nude" + ] + * len(prompts), + ).images def forward_loop(model): - pipeline.unet = model + if self.version not in ["sd3", "flux.1-dev"]: + pipeline.unet = model + else: + pipeline.transformer = model + do_calibrate( pipeline=pipeline, calibration_prompts=calibration_prompts, calib_size=calibration_size // calib_batch_size, n_steps=self.denoising_steps, + model_id=self.version, + **kwargs ) print(f"[I] Performing calibration for {calibration_size} steps.") @@ -375,25 +408,36 @@ def forward_loop(model): self.denoising_steps ) elif model_config['use_fp8']: - quant_config = SD_FP8_FP32_DEFAULT_CONFIG if self.version == "2.1" else SD_FP8_FP16_DEFAULT_CONFIG + if self.version == "flux.1-dev": + quant_config = SD_FP8_BF16_DEFAULT_CONFIG + elif self.version == "2.1": + quant_config = SD_FP8_FP32_DEFAULT_CONFIG + else: + quant_config = SD_FP8_FP16_DEFAULT_CONFIG + check_lora(model) + if self.version == "flux.1-dev": + set_quant_precision(quant_config, "BFloat16") mtq.quantize(model, quant_config, forward_loop) mto.save(model, model_config['state_dict_path']) - def _get_quantized_model(self, obj, model_config, quantization_level, quantization_percentile, quantization_alpha, calibration_size, calib_batch_size): + def _get_quantized_model(self, obj, model_config, quantization_level, quantization_percentile, quantization_alpha, calibration_size, calib_batch_size, **kwargs): pipeline = obj.get_pipeline() - model = pipeline.unet + model = pipeline.unet if self.version not in ["sd3", "flux.1-dev"] else pipeline.transformer if model_config['use_fp8'] and quantization_level == 4.0: set_fmha(model) if not os.path.exists(model_config['state_dict_path']): - self._calibrate_and_save_model(pipeline, model, model_config, quantization_level, quantization_percentile, quantization_alpha, calibration_size, calib_batch_size) + self._calibrate_and_save_model(pipeline, model, model_config, quantization_level, quantization_percentile, quantization_alpha, calibration_size, calib_batch_size, **kwargs) else: mto.restore(model, model_config['state_dict_path']) if not os.path.exists(model_config['onnx_path']): quantize_lvl(model, quantization_level) - mtq.disable_quantizer(model, filter_func) + if self.version in ["flux.1-dev"]: + mtq.disable_quantizer(model, filter_func_no_proj_out) + else: + mtq.disable_quantizer(model, filter_func) if model_config['use_fp8']: generate_fp8_scales(model) else: @@ -407,10 +451,10 @@ def _export_onnx(self, obj, model_config, opt_image_height, opt_image_width, sta if do_export_onnx or do_export_weights_map: if not model_config['use_int8'] and not model_config['use_fp8']: - obj.export_onnx(model_config['onnx_path'], model_config['onnx_opt_path'], onnx_opset, opt_image_height, opt_image_width, enable_lora_merge=model_config['do_lora_merge'], static_shape=static_shape) + obj.export_onnx(model_config['onnx_path'], model_config['onnx_opt_path'], onnx_opset, opt_image_height, opt_image_width, enable_lora_merge=model_config['do_lora_merge'], static_shape=static_shape, lora_loader=self.lora_loader) else: print(f"[I] Generating quantized ONNX model: {model_config['onnx_path']}") - quantized_model = self._get_quantized_model(obj, model_config, quantization_level, quantization_percentile, quantization_alpha, calibration_size, calib_batch_size) + quantized_model = self._get_quantized_model(obj, model_config, quantization_level, quantization_percentile, quantization_alpha, calibration_size, calib_batch_size, height=opt_image_width, width=opt_image_width) obj.export_onnx(model_config['onnx_path'], model_config['onnx_opt_path'], onnx_opset, opt_image_height, opt_image_width, custom_model=quantized_model, static_shape=static_shape) # FIXME do_export_weights_map needs ONNX graph @@ -453,14 +497,14 @@ def _refit_engine(self, obj, model_name, model_config): if not os.path.exists(model_config['refit_weights_path']): print(f"[I] Saving refit weights: {model_config['refit_weights_path']}") - model = merge_loras(obj.get_model(), obj.lora_dict, obj.lora_alphas, obj.lora_scales) - refit_weights = get_refit_weights(model.state_dict(), model_config['onnx_opt_path'], weights_name_mapping, weights_shape_mapping) - torch.save(refit_weights, model_config['refit_weights_path']) + model = merge_loras(obj.get_model(), self.lora_loader) + refit_weights, updated_weight_names = get_refit_weights(model.state_dict(), model_config['onnx_opt_path'], weights_name_mapping, weights_shape_mapping) + torch.save((refit_weights, updated_weight_names), model_config['refit_weights_path']) unload_model(model) else: print(f"[I] Loading refit weights: {model_config['refit_weights_path']}") - refit_weights = torch.load(model_config['refit_weights_path']) - self.engine[model_name].refit(refit_weights, obj.fp16) + refit_weights, updated_weight_names = torch.load(model_config['refit_weights_path']) + self.engine[model_name].refit(refit_weights, updated_weight_names) def _load_torch_models(self): # Load torch models @@ -551,6 +595,10 @@ def load_engines( continue self._export_onnx(obj, model_configs[model_name], opt_image_height, opt_image_width, static_shape, onnx_opset, quantization_level, quantization_percentile, quantization_alpha, calibration_size, calib_batch_size) + # Release temp GPU memory during onnx export to avoid OOM. + gc.collect() + torch.cuda.empty_cache() + # Build TensorRT engines for model_name, obj in self.models.items(): if self.torch_fallback[model_name]: @@ -566,11 +614,21 @@ def load_engines( for model_name, obj in self.models.items(): if self.torch_fallback[model_name]: continue - self.engine[model_name].load() - model_config = model_configs[model_name] - if model_config['do_engine_refit'] and obj.lora_dict: + + # For non low_vram case, the engines will remain in GPU memory from now on. + assert self.engine[model_name].engine is None + if not self.low_vram: + self.engine[model_name].load() + + if model_config['do_engine_refit'] and self.lora_loader: + # For low_vram, using on-demand load and unload for refit. + if self.low_vram: + assert self.engine[model_name].engine is None + self.engine[model_name].load() self._refit_engine(obj, model_name, model_config) + if self.low_vram: + self.engine[model_name].unload() # Load PyTorch models if torch-inference mode is enabled self._load_torch_models() @@ -581,7 +639,11 @@ def load_engines( def calculate_max_device_memory(self): max_device_memory = 0 for model_name, engine in self.engine.items(): + if self.low_vram: + engine.load() max_device_memory = max(max_device_memory, engine.engine.device_memory_size) + if self.low_vram: + engine.unload() return max_device_memory def activate_engines(self, shared_device_memory=None): @@ -590,11 +652,15 @@ def activate_engines(self, shared_device_memory=None): _, shared_device_memory = cudart.cudaMalloc(max_device_memory) self.shared_device_memory = shared_device_memory # Load and activate TensorRT engines - for engine in self.engine.values(): - engine.activate(device_memory=self.shared_device_memory) + if not self.low_vram: + for engine in self.engine.values(): + engine.activate(device_memory=self.shared_device_memory) def run_engine(self, model_name, feed_dict): engine = self.engine[model_name] + # CUDA graphs should be disabled when low_vram is enabled. + if self.low_vram: + assert self.use_cuda_graph == False return engine.infer(feed_dict, self.stream, use_cuda_graph=self.use_cuda_graph) def teardown(self): @@ -645,4 +711,3 @@ def infer(self): def run(self): """Run the pipeline.""" raise NotImplementedError("Please Implement the run method") - diff --git a/demo/Diffusion/flux_pipeline.py b/demo/Diffusion/flux_pipeline.py index 4de8e0f7a..fb764d68d 100644 --- a/demo/Diffusion/flux_pipeline.py +++ b/demo/Diffusion/flux_pipeline.py @@ -61,7 +61,9 @@ def __init__( pipeline_type=PIPELINE_TYPE.TXT2IMG, guidance_scale=3.5, max_sequence_length=512, - **kwargs, + bf16=False, + low_vram=False, + **kwargs ): """ Initializes the Flux pipeline. @@ -72,10 +74,14 @@ def __init__( Higher guidance scale encourages to generate images that are closely linked to the text prompt, usually at the expense of lower image quality. max_sequence_length (`int`, defaults to 512): Maximum sequence length to use with the `prompt`. + bf16 (`bool`, defaults to False): + Whether to run the pipeline in BFloat16 precision. """ super().__init__(version=version, pipeline_type=pipeline_type, **kwargs) self.guidance_scale = guidance_scale self.max_sequence_length = max_sequence_length + self.bf16=bf16 + self.low_vram = low_vram # Pipeline type self.stages = ["clip", "t5", "transformer", "vae"] @@ -105,13 +111,15 @@ def _initialize_models(self, framework_model_dir, int8, fp8): "max_batch_size": self.max_batch_size, } - self.fp16 = True + self.bf16 = True if int8 or fp8 else self.bf16 + self.fp16 = True if not self.bf16 else False self.tf32 = True if "clip" in self.stages: self.models["clip"] = CLIPModel( **models_args, fp16=self.fp16, tf32=self.tf32, + bf16=self.bf16, embedding_dim=get_clip_embedding_dim(self.version, self.pipeline_type), keep_pooled_output=True, subfolder="text_encoder", @@ -123,6 +131,7 @@ def _initialize_models(self, framework_model_dir, int8, fp8): **models_args, fp16=False, tf32=self.tf32, + bf16=self.bf16, subfolder="text_encoder_2", text_maxlen=self.max_sequence_length, ) @@ -130,15 +139,18 @@ def _initialize_models(self, framework_model_dir, int8, fp8): if "transformer" in self.stages: self.models["transformer"] = FluxTransformerModel( **models_args, - fp16=self.fp16, + bf16=True if int8 or fp8 else self.bf16, + fp16=False if int8 or fp8 else self.fp16, + int8=int8, + fp8=fp8, tf32=self.tf32, text_maxlen=self.max_sequence_length, - build_strongly_typed=False, + build_strongly_typed=True, ) if "vae" in self.stages: # Accuracy issues with FP16 - self.models["vae"] = VAEModel(**models_args, fp16=False, tf32=self.tf32) + self.models["vae"] = VAEModel(**models_args, fp16=False, tf32=self.tf32, bf16=self.bf16) self.vae_scale_factor = ( 2 ** (len(self.models["vae"].config["block_out_channels"])) @@ -297,9 +309,7 @@ def tokenize(prompt, max_sequence_length): text_encoder_output = tokenize(prompt, max_sequence_length) self.profile_stop(encoder) - return ( - text_encoder_output.to(torch.float16) if self.fp16 else text_encoder_output - ) + return text_encoder_output.to(torch.float16) if self.fp16 else text_encoder_output.to(torch.bfloat16) if self.bf16 else text_encoder_output def denoise_latent( self, @@ -347,7 +357,7 @@ def denoise_latent( )[0] self.profile_stop(denoiser) - return latents.to(dtype=torch.float32) + return latents.to(dtype=torch.bfloat16) if self.bf16 else latents.to(dtype=torch.float32) def decode_latent(self, latents, decoder="vae"): self.profile_start(decoder, color="red") @@ -445,17 +455,40 @@ def infer( // 4, latent_height=latent_height, latent_width=latent_width, - latents_dtype=torch.float16 if self.fp16 else torch.float32, - ) + latents_dtype=torch.float16 if self.fp16 else torch.bfloat16 if self.bf16 else torch.float32) + + class LoadModelContext: + def __init__(ctx, model_names, low_vram=False): + ctx.model_names = model_names + ctx.low_vram = low_vram + def __enter__(ctx): + if not ctx.low_vram: + return + for model_name in ctx.model_names: + # creating engine object (load from plan file) + self.engine[model_name].load() + # creating context + self.engine[model_name].activate(device_memory=self.shared_device_memory) + # creating input and output buffer + self.engine[model_name].allocate_buffers(shape_dict=self.shape_dicts[model_name], device=self.device) + def __exit__(ctx, exc_type, exc_val, exc_tb): + if not ctx.low_vram: + return + for model_name in ctx.model_names: + self.engine[model_name].deallocate_buffers() + self.engine[model_name].deactivate() + self.engine[model_name].unload() # CLIP and T5 text encoder(s) - pooled_embeddings = self.encode_prompt(prompt, pooled_output=True) - text_embeddings = self.encode_prompt( - prompt2, encoder="t5", max_sequence_length=self.max_sequence_length - ) - text_ids = torch.zeros(text_embeddings.shape[1], 3).to( - device=self.device, dtype=text_embeddings.dtype - ) + + with LoadModelContext(["clip","t5"], low_vram=self.low_vram): + pooled_embeddings = self.encode_prompt(prompt, pooled_output=True) + text_embeddings = self.encode_prompt( + prompt2, encoder="t5", max_sequence_length=self.max_sequence_length + ) + text_ids = torch.zeros(text_embeddings.shape[1], 3).to( + device=self.device, dtype=text_embeddings.dtype + ) # Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) @@ -486,23 +519,25 @@ def infer( num_inference_steps = len(timesteps) # DiT denoiser - latents = self.denoise_latent( - latents, - timesteps, - text_embeddings, - pooled_embeddings, - text_ids, - latent_image_ids, - ) + with LoadModelContext(["transformer"], low_vram=self.low_vram): + latents = self.denoise_latent( + latents, + timesteps, + text_embeddings, + pooled_embeddings, + text_ids, + latent_image_ids, + ) # VAE decode latent - latents = self._unpack_latents( - latents, image_height, image_width, self.vae_scale_factor - ) - latents = ( - latents / self.models["vae"].config["scaling_factor"] - ) + self.models["vae"].config["shift_factor"] - images = self.decode_latent(latents) + with LoadModelContext(["vae"], low_vram=self.low_vram): + latents = self._unpack_latents( + latents, image_height, image_width, self.vae_scale_factor + ) + latents = ( + latents / self.models["vae"].config["scaling_factor"] + ) + self.models["vae"].config["shift_factor"] + images = self.decode_latent(latents) torch.cuda.synchronize() e2e_toc = time.perf_counter() @@ -539,6 +574,9 @@ def run( use_cuda_graph, **kwargs, ): + if self.low_vram and self.use_cuda_graph: + print("[W] Using low_vram, use_cuda_graph will be disabled") + self.use_cuda_graph = False num_warmup_runs = max(1, num_warmup_runs) if use_cuda_graph else num_warmup_runs if num_warmup_runs > 0: print("[I] Warming up ..") diff --git a/demo/Diffusion/models.py b/demo/Diffusion/models.py index 241ca8214..1f9a71edb 100755 --- a/demo/Diffusion/models.py +++ b/demo/Diffusion/models.py @@ -16,7 +16,7 @@ # from diffusers import DiffusionPipeline -from diffusers.loaders import LoraLoaderMixin +from diffusers.loaders import StableDiffusionLoraLoaderMixin from diffusers.pipelines.wuerstchen import PaellaVQModel import json import numpy as np @@ -186,7 +186,7 @@ def fuse_mha_qkv_int8_sq(self): print(f"Removed {removed} QDQ nodes") return removed # expected 72 for L2.5 - def modify_fp8_graph(self): + def modify_fp8_graph(self, is_fp16_io=True): onnx_graph = gs.export_onnx(self.graph) # Convert INT8 Zero to FP8. onnx_graph = convert_zp_fp8(onnx_graph) @@ -196,7 +196,8 @@ def modify_fp8_graph(self): # Add cast nodes to Resize I/O. cast_resize_io(self.graph) # Convert model inputs and outputs to fp16 I/O. - convert_fp16_io(self.graph) + if is_fp16_io: + convert_fp16_io(self.graph) # Add cast nodes to MHA's BMM1 and BMM2's I/O. cast_fp8_mha_io(self.graph) @@ -289,52 +290,15 @@ def optimize_checkpoint(model, torch_inference): assert torch_inference in torch_inference_modes return torch.compile(model, mode=torch_inference, dynamic=False, fullgraph=False) -class LoraLoader(LoraLoaderMixin): +class LoraLoader(StableDiffusionLoraLoaderMixin): def __init__(self, paths, + weights, + scale ): self.paths = paths - self.state_dict = dict() - self.network_alphas = dict() - - for path in paths: - state_dict, network_alphas = self.lora_state_dict(path) - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.state_dict[path] = state_dict - self.network_alphas[path] = network_alphas - - def get_dicts(self, - prefix='unet', - convert_to_diffusers=False, - ): - state_dict = dict() - network_alphas = dict() - - for path in self.paths: - keys = list(self.state_dict[path].keys()) - if all(key.startswith(('unet', 'text_encoder')) for key in keys): - keys = [k for k in keys if k.startswith(prefix)] - if keys: - print(f"Processing {prefix} LoRA: {path}") - state_dict[path] = {k.replace(f"{prefix}.", ""): v for k, v in self.state_dict[path].items() if k in keys} - - network_alphas[path] = None - if path in self.network_alphas and self.network_alphas[path] is not None: - alpha_keys = [k for k in self.network_alphas[path].keys() if k.startswith(prefix)] - network_alphas[path] = { - k.replace(f"{prefix}.", ""): v for k, v in self.network_alphas[path].items() if k in alpha_keys - } - - else: - # Otherwise, we're dealing with the old format. - warn_message = "You have saved the LoRA weights using the old format. To convert LoRA weights to the new format, first load them in a dictionary and then create a new dictionary as follows: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`." - print(warn_message) - - return state_dict, network_alphas - + self.weights = weights + self.scale = scale class BaseModel(): def __init__(self, @@ -383,12 +347,11 @@ def __init__(self, self.embedding_dim = embedding_dim self.extra_output_names = [] - self.lora_dict = None self.do_constant_folding = True def get_pipeline(self): model_opts = {'variant': 'fp16', 'torch_dtype': torch.float16} if self.fp16 else {} - model_opts = {'variant': 'bf16', 'torch_dtype': torch.bfloat16} if self.bf16 else model_opts + model_opts = {'torch_dtype': torch.bfloat16} if self.bf16 else model_opts return DiffusionPipeline.from_pretrained( self.path, use_safetensors=self.hf_safetensor, @@ -434,6 +397,7 @@ def export_onnx( custom_model=None, enable_lora_merge=False, static_shape=False, + lora_loader=None ): onnx_opt_graph = None # Export optimized ONNX model (if missing) @@ -442,7 +406,8 @@ def export_onnx( print(f"[I] Exporting ONNX model: {onnx_path}") def export_onnx(model): if enable_lora_merge: - model = merge_loras(model, self.lora_dict, self.lora_alphas, self.lora_scales) + assert lora_loader is not None + model = merge_loras(model, lora_loader) inputs = self.get_sample_input(1, opt_image_height, opt_image_width, static_shape) torch.onnx.export(model, inputs, @@ -530,16 +495,16 @@ def optimize(self, onnx_graph, return_onnx=True, **kwargs): opt.cleanup() opt.info(self.name + ': cleanup') if kwargs.get('modify_fp8_graph', False): - opt.modify_fp8_graph() + is_fp16_io = kwargs.get('is_fp16_io', True) + opt.modify_fp8_graph(is_fp16_io=is_fp16_io) opt.info(self.name + ': modify fp8 graph') - else: - opt.fold_constants() - opt.info(self.name + ': fold constants') - opt.infer_shapes() - opt.info(self.name + ': shape inference') - if kwargs.get('fuse_mha_qkv_int8', False): - opt.fuse_mha_qkv_int8_sq() - opt.info(self.name + ': fuse QKV nodes') + opt.fold_constants() + opt.info(self.name + ': fold constants') + opt.infer_shapes() + opt.info(self.name + ': shape inference') + if kwargs.get('fuse_mha_qkv_int8', False): + opt.fuse_mha_qkv_int8_sq() + opt.info(self.name + ': fuse QKV nodes') onnx_opt_graph = opt.cleanup(return_onnx=return_onnx) opt.info(self.name + ': finished') return onnx_opt_graph @@ -584,8 +549,6 @@ def __init__(self, output_hidden_states=False, keep_pooled_output=False, subfolder="text_encoder", - lora_dict=None, - lora_alphas=None, ): super(CLIPModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, tf32=tf32, bf16=bf16, max_batch_size=max_batch_size, embedding_dim=embedding_dim) self.subfolder = subfolder @@ -597,7 +560,7 @@ def __init__(self, self.extra_output_names = ['hidden_states'] def get_model(self, torch_inference=''): - model_opts = {'torch_dtype': torch.float16} if self.fp16 else {} + model_opts = {'torch_dtype': torch.float16} if self.fp16 else {'torch_dtype': torch.bfloat16} if self.bf16 else {} clip_model_dir = get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, self.subfolder) if not os.path.exists(clip_model_dir): model = CLIPTextModel.from_pretrained(self.path, @@ -686,8 +649,6 @@ def __init__(self, max_batch_size=16, output_hidden_states=False, subfolder="text_encoder_2", - lora_dict=None, - lora_alphas=None, ): super(CLIPWithProjModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, bf16=bf16, max_batch_size=max_batch_size, embedding_dim=get_clipwithproj_embedding_dim(version, pipeline), output_hidden_states=output_hidden_states) @@ -769,7 +730,7 @@ def __init__(self, self.config = AutoConfig.from_pretrained(self.path, subfolder=self.subfolder, token=self.hf_token) def get_model(self, torch_inference=''): - model_opts = {'torch_dtype': torch.float16} if self.fp16 else {} + model_opts = {'torch_dtype': torch.float16} if self.fp16 else {'torch_dtype': torch.bfloat16} if self.bf16 else {} t5_model_dir = get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, self.subfolder) if not os.path.exists(t5_model_dir): model = T5EncoderModel.from_pretrained(self.path, @@ -1083,9 +1044,6 @@ def __init__(self, max_batch_size = 16, text_maxlen = 77, controlnets = None, - lora_scales = None, - lora_dict = None, - lora_alphas = None, do_classifier_free_guidance = False, ): @@ -1093,9 +1051,6 @@ def __init__(self, self.subfolder = 'unet' self.controlnets = get_path(version, pipeline, controlnets) if controlnets else None self.unet_dim = (9 if pipeline.is_inpaint() else 4) - self.lora_scales = lora_scales - self.lora_dict = lora_dict - self.lora_alphas = lora_alphas self.xB = 2 if do_classifier_free_guidance else 1 # batch multiplier def get_model(self, torch_inference=''): @@ -1241,18 +1196,12 @@ def __init__(self, fp8 = False, max_batch_size = 16, text_maxlen = 77, - lora_scales = None, - lora_dict = None, - lora_alphas = None, do_classifier_free_guidance = False, ): super(UNetXLModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, int8=int8, fp8=fp8, max_batch_size=max_batch_size, text_maxlen=text_maxlen, embedding_dim=get_unet_embedding_dim(version, pipeline)) self.subfolder = 'unet' self.unet_dim = (9 if pipeline.is_inpaint() else 4) self.time_dim = (5 if pipeline.is_sd_xl_refiner() else 6) - self.lora_scales = lora_scales - self.lora_dict = lora_dict - self.lora_alphas = lora_alphas self.xB = 2 if do_classifier_free_guidance else 1 # batch multiplier def get_model(self, torch_inference=''): @@ -1498,7 +1447,7 @@ def get_shape_dict(self, batch_size, image_height, image_width): 'added_time_ids': (self.xB*batch_size, 3), } - def get_sample_input(self, batch_size, image_height, image_width): + def get_sample_input(self, batch_size, image_height, image_width, static_shape): # TODO chunk_size if forward_chunking is used latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) @@ -1668,18 +1617,18 @@ def __init__(self, tf32=False, int8 = False, fp8 = False, + bf16 = False, max_batch_size = 16, text_maxlen = 77, build_strongly_typed=False ): - - super(FluxTransformerModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, tf32=tf32, int8=int8, fp8=fp8, max_batch_size=max_batch_size, text_maxlen=text_maxlen) + super(FluxTransformerModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, tf32=tf32, int8=int8, fp8=fp8, bf16=bf16, max_batch_size=max_batch_size, text_maxlen=text_maxlen) self.subfolder = 'transformer' self.config = FluxTransformer2DModel.load_config(self.path, subfolder=self.subfolder, token=self.hf_token) self.build_strongly_typed = build_strongly_typed def get_model(self, torch_inference=''): - model_opts = {'torch_dtype': torch.float16} if self.fp16 else {} + model_opts = {'torch_dtype': torch.float16} if self.fp16 else {'torch_dtype': torch.bfloat16} if self.bf16 else {} transformer_model_dir = get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, self.subfolder) transformer_path = self.get_model_path(transformer_model_dir, model_opts) if not os.path.exists(transformer_path): @@ -1698,7 +1647,7 @@ def get_model(self, torch_inference=''): return model def get_input_names(self): - return ['hidden_states', 'encoder_hidden_states', 'pooled_projections', 'timestep', 'img_ids', 'txt_ids', 'guidance'] + return ['hidden_states', 'encoder_hidden_states', 'pooled_projections', 'timestep', 'img_ids', 'txt_ids', 'guidance'] def get_output_names(self): return ['latent'] @@ -1743,18 +1692,28 @@ def get_shape_dict(self, batch_size, image_height, image_width): def get_sample_input(self, batch_size, image_height, image_width, static_shape): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - dtype = torch.float16 if self.fp16 else torch.float32 + dtype = torch.float32 + assert not (self.fp16 and self.bf16), "fp16 and bf16 cannot be enabled simultaneously" + tensor_dtype = torch.bfloat16 if self.bf16 else (torch.float16 if self.fp16 else torch.float32) + return ( - torch.randn(batch_size, (latent_height // 2) * (latent_width // 2), self.config['in_channels'], dtype=dtype, device=self.device), - torch.randn(batch_size, self.text_maxlen, self.config['joint_attention_dim'], dtype=dtype, device=self.device), - torch.randn(batch_size, self.config['pooled_projection_dim'], dtype=dtype, device=self.device), - torch.tensor([1.]*batch_size, dtype=dtype, device=self.device), + torch.randn(batch_size, (latent_height // 2) * (latent_width // 2), self.config['in_channels'], dtype=tensor_dtype, device=self.device), + torch.randn(batch_size, self.text_maxlen, self.config['joint_attention_dim'], dtype=tensor_dtype, device=self.device), + torch.randn(batch_size, self.config['pooled_projection_dim'], dtype=tensor_dtype, device=self.device), + torch.tensor([1.]*batch_size, dtype=tensor_dtype, device=self.device), torch.randn((latent_height // 2) * (latent_width // 2), 3, dtype=dtype, device=self.device), torch.randn(self.text_maxlen, 3, dtype=dtype, device=self.device), { 'guidance': torch.tensor([1.]*batch_size, dtype=dtype, device=self.device), } ) + def optimize(self, onnx_graph): + if self.fp8: + return super().optimize(onnx_graph, modify_fp8_graph=True, is_fp16_io=False) + if self.int8: + return super().optimize(onnx_graph, fuse_mha_qkv_int8=True) + return super().optimize(onnx_graph) + class VAEModel(BaseModel): def __init__(self, @@ -1766,23 +1725,26 @@ def __init__(self, framework_model_dir, fp16=False, tf32=False, + bf16=False, max_batch_size=16, ): - super(VAEModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, tf32=tf32, max_batch_size=max_batch_size) + super(VAEModel, self).__init__(version, pipeline, device=device, hf_token=hf_token, verbose=verbose, framework_model_dir=framework_model_dir, fp16=fp16, tf32=tf32, bf16=bf16, max_batch_size=max_batch_size) self.subfolder = 'vae' self.config = AutoencoderKL.load_config(self.path, subfolder=self.subfolder, token=self.hf_token) def get_model(self, torch_inference=''): + model_opts = {'torch_dtype': torch.float16} if self.fp16 else {'torch_dtype': torch.bfloat16} if self.bf16 else {} vae_decoder_model_path = get_checkpoint_dir(self.framework_model_dir, self.version, self.pipeline, self.subfolder) if not os.path.exists(vae_decoder_model_path): model = AutoencoderKL.from_pretrained(self.path, subfolder=self.subfolder, use_safetensors=self.hf_safetensor, - token=self.hf_token).to(self.device) - model.save_pretrained(vae_decoder_model_path) + token=self.hf_token, + **model_opts).to(self.device) + model.save_pretrained(vae_decoder_model_path, **model_opts) else: print(f"[I] Load AutoencoderKL (decoder) model from: {vae_decoder_model_path}") - model = AutoencoderKL.from_pretrained(vae_decoder_model_path).to(self.device) + model = AutoencoderKL.from_pretrained(vae_decoder_model_path, **model_opts).to(self.device) model.forward = model.decode model = optimize_checkpoint(model, torch_inference) return model @@ -1819,7 +1781,8 @@ def get_shape_dict(self, batch_size, image_height, image_width): def get_sample_input(self, batch_size, image_height, image_width, static_shape): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return torch.randn(batch_size, self.config['latent_channels'], latent_height, latent_width, dtype=torch.float32, device=self.device) + dtype = torch.float16 if self.fp16 else torch.bfloat16 if self.bf16 else torch.float32 + return torch.randn(batch_size, self.config['latent_channels'], latent_height, latent_width, dtype=dtype, device=self.device) class SD3_VAEDecoderModel(BaseModel): def __init__(self, diff --git a/demo/Diffusion/requirements.txt b/demo/Diffusion/requirements.txt index 4e494185b..2316b8789 100755 --- a/demo/Diffusion/requirements.txt +++ b/demo/Diffusion/requirements.txt @@ -7,13 +7,14 @@ git+https://github.com/huggingface/diffusers.git # Install from source for the l ftfy matplotlib nvtx -onnx==1.15.0 -onnxruntime==1.17.3 +onnx==1.17.0 +onnxruntime==1.19.2 opencv-python==4.8.0.74 scipy transformers==4.42.2 --extra-index-url https://pypi.nvidia.com -nvidia-modelopt[torch,onnx]==0.15.1 +nvidia-modelopt[torch,onnx]==0.19.0 onnx-graphsurgeon +peft==0.13.0 polygraphy==0.49.9 sentencepiece diff --git a/demo/Diffusion/stable_diffusion_pipeline.py b/demo/Diffusion/stable_diffusion_pipeline.py index 332b55723..cc64e70ea 100644 --- a/demo/Diffusion/stable_diffusion_pipeline.py +++ b/demo/Diffusion/stable_diffusion_pipeline.py @@ -32,7 +32,6 @@ import inspect from models import ( get_clip_embedding_dim, - get_path, LoraLoader, make_tokenizer, CLIPModel, @@ -108,7 +107,8 @@ def __init__( vae_scaling_factor=0.18215, framework_model_dir='pytorch_model', controlnets=None, - lora_scale: Optional[List[int]] = None, + lora_scale: float = 1.0, + lora_weight: Optional[List[float]] = None, lora_path: Optional[List[str]] = None, return_latents=False, torch_inference='', @@ -239,12 +239,12 @@ def __init__( # initialize lora loader and scales self.lora_loader = None - self.lora_scales = dict() + self.lora_weights = dict() if lora_path: - self.lora_loader = LoraLoader(lora_path) - assert len(lora_path) == len(lora_scale) + self.lora_loader = LoraLoader(lora_path, lora_weight, lora_scale) + assert len(lora_path) == len(lora_weight) for i, path in enumerate(lora_path): - self.lora_scales[path] = lora_scale[i] + self.lora_weights[path] = lora_weight[i] # initialized in loadResources() self.events = {} @@ -335,20 +335,11 @@ def initializeModels(self, framework_model_dir, int8, fp8): subfolder = 'text_encoder_2' self.models['clip2'] = CLIPWithProjModel(**models_args, fp16=True, output_hidden_states=self.config.get('clip_hidden_states', False), subfolder=subfolder) - lora_dict, lora_alphas = (None, None) if 'unet' in self.stages: - if self.lora_loader: - lora_dict, lora_alphas = self.lora_loader.get_dicts('unet') - assert len(lora_dict) == len(self.lora_scales) - self.models['unet'] = UNetModel(**models_args, fp16=True, int8=int8, fp8=fp8, controlnets=self.controlnets, - lora_scales=self.lora_scales, lora_dict=lora_dict, lora_alphas=lora_alphas, do_classifier_free_guidance=self.do_classifier_free_guidance) + self.models['unet'] = UNetModel(**models_args, fp16=True, int8=int8, fp8=fp8, controlnets=self.controlnets, do_classifier_free_guidance=self.do_classifier_free_guidance) if 'unetxl' in self.stages: - if not self.pipeline_type.is_sd_xl_refiner() and self.lora_loader: - lora_dict, lora_alphas = self.lora_loader.get_dicts('unet') - assert len(lora_dict) == len(self.lora_scales) - self.models['unetxl'] = UNetXLModel(**models_args, fp16=True, int8=int8, fp8=fp8, - lora_scales=self.lora_scales, lora_dict=lora_dict, lora_alphas=lora_alphas, do_classifier_free_guidance=self.do_classifier_free_guidance) + self.models['unetxl'] = UNetXLModel(**models_args, fp16=True, int8=int8, fp8=fp8, do_classifier_free_guidance=self.do_classifier_free_guidance) vae_fp16 = not self.pipeline_type.is_sd_xl() @@ -441,7 +432,7 @@ def loadEngines( # Configure pipeline models to load model_names = self.models.keys() - lora_suffix = '-'+'-'.join([str(md5(path.encode('utf-8')).hexdigest())+'-'+('%.2f' % self.lora_scales[path]) for path in sorted(self.lora_loader.paths)]) if self.lora_loader else '' + lora_suffix = '-'+'-'.join([str(md5(path.encode('utf-8')).hexdigest())+'-'+('%.2f' % self.lora_weights[path])+'-'+('%.2f' % self.lora_loader.scale) for path in sorted(self.lora_loader.paths)]) if self.lora_loader else '' # Enable refit and LoRA merging only for UNet & UNetXL for now do_engine_refit = dict(zip(model_names, [not self.pipeline_type.is_sd_xl_refiner() and enable_refit and model_name.startswith('unet') for model_name in model_names])) do_lora_merge = dict(zip(model_names, [not enable_refit and self.lora_loader and model_name.startswith('unet') for model_name in model_names])) @@ -474,7 +465,7 @@ def loadEngines( if do_export_onnx or do_export_weights_map: # Non-quantized ONNX export if not use_int8[model_name] and not use_fp8[model_name]: - obj.export_onnx(onnx_path[model_name], onnx_opt_path[model_name], onnx_opset, opt_image_height, opt_image_width, enable_lora_merge=do_lora_merge[model_name], static_shape=static_shape) + obj.export_onnx(onnx_path[model_name], onnx_opt_path[model_name], onnx_opset, opt_image_height, opt_image_width, enable_lora_merge=do_lora_merge[model_name], static_shape=static_shape, lora_loader=self.lora_loader) else: pipeline = obj.get_pipeline() model = pipeline.unet @@ -576,7 +567,7 @@ def forward_loop(model): if torch_fallback[model_name]: continue self.engine[model_name].load() - if do_engine_refit[model_name] and obj.lora_dict: + if do_engine_refit[model_name] and self.lora_loader: assert weights_map_path[model_name] with open(weights_map_path[model_name], 'r') as fp_wts: print(f"[I] Loading weights map: {weights_map_path[model_name]} ") @@ -584,14 +575,14 @@ def forward_loop(model): refit_weights_path = self.getRefitNodesPath(model_name, engine_dir, suffix=lora_suffix) if not os.path.exists(refit_weights_path): print(f"[I] Saving refit weights: {refit_weights_path}") - model = merge_loras(obj.get_model(), obj.lora_dict, obj.lora_alphas, obj.lora_scales) - refit_weights = get_refit_weights(model.state_dict(), onnx_opt_path[model_name], weights_name_mapping, weights_shape_mapping) - torch.save(refit_weights, refit_weights_path) + model = merge_loras(obj.get_model(), self.lora_loader) + refit_weights, updated_weight_names = get_refit_weights(model.state_dict(), onnx_opt_path[model_name], weights_name_mapping, weights_shape_mapping) + torch.save((refit_weights, updated_weight_names), refit_weights_path) unload_model(model) else: print(f"[I] Loading refit weights: {refit_weights_path}") - refit_weights = torch.load(refit_weights_path) - self.engine[model_name].refit(refit_weights, obj.fp16) + refit_weights, updated_weight_names = torch.load(refit_weights_path) + self.engine[model_name].refit(refit_weights, updated_weight_names) # Load torch models for model_name, obj in self.models.items(): @@ -837,6 +828,9 @@ def encode_image(self, input_image): def decode_latent(self, latents): self.profile_start('vae', color='red') + cast_to = torch.float16 if self.models['vae'].fp16 else torch.bfloat16 if self.models['vae'].bf16 else torch.float32 + latents = latents.to(dtype=cast_to) + if self.torch_inference: images = self.torch_models['vae'](latents, return_dict=False)[0] else: diff --git a/demo/Diffusion/stable_video_diffusion_pipeline.py b/demo/Diffusion/stable_video_diffusion_pipeline.py index e2370539e..c9af7247f 100644 --- a/demo/Diffusion/stable_video_diffusion_pipeline.py +++ b/demo/Diffusion/stable_video_diffusion_pipeline.py @@ -41,7 +41,19 @@ _append_dims, _resize_with_antialiasing, tensor2vid, + load_calibration_images, ) +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq +from utils_modelopt import ( + filter_func, + quantize_lvl, + check_lora, + set_fmha, + generate_fp8_scales, + SD_FP8_FP16_DEFAULT_CONFIG, +) + from stable_diffusion_pipeline import StableDiffusionPipeline class StableVideoDiffusionPipeline(StableDiffusionPipeline): @@ -151,6 +163,10 @@ def loadEngines( enable_refit=False, enable_all_tactics=False, timing_cache=None, + fp8=False, + quantization_level=0.0, + calibration_size=32, + calib_batch_size=2 ): """ Build and load engines for TensorRT accelerated inference. @@ -181,6 +197,14 @@ def loadEngines( Enable all tactic sources during TensorRT engine builds. timing_cache (str): Path to the timing cache to speed up TensorRT build. + fp8 (bool): + Whether to quantize to fp8 format or not. + quantization_level (float): + Controls which layers to quantize. + calibration_size (int): + The number of steps to use for calibrating the model for quantization. + calib_batch_size (int): + The batch size to use for calibration. Defaults to 2. """ # Create directories if missing for directory in [engine_dir, onnx_dir]: @@ -210,13 +234,83 @@ def loadEngines( engine_path = dict(zip(model_names, [self.getEnginePath(model_name, engine_dir) for model_name in model_names])) do_engine_refit = dict(zip(model_names, [enable_refit and model_name.startswith('unet') for model_name in model_names])) + # Quantization. + model_suffix = dict(zip(model_names, ['' for model_name in model_names])) + use_fp8 = dict.fromkeys(model_names, False) + if fp8: + model_name = "unet-temp" + use_fp8[model_name] = True + model_suffix[model_name] += f"-fp8.l{quantization_level}.bs2.s{self.denoising_steps}.c{calibration_size}" + onnx_path = { model_name : self.getOnnxPath(model_name, onnx_dir, opt=False, suffix=model_suffix[model_name]) for model_name in model_names } + onnx_opt_path = { model_name : self.getOnnxPath(model_name, onnx_dir, suffix=model_suffix[model_name]) for model_name in model_names } + engine_path = { model_name : self.getEnginePath(model_name, engine_dir, do_engine_refit[model_name], suffix=model_suffix[model_name]) for model_name in model_names } + weights_map_path = { model_name : (self.getWeightsMapPath(model_name, onnx_dir) if do_engine_refit[model_name] else None) for model_name in model_names } + + # Export models to ONNX for model_name, obj in self.models.items(): if self.torch_fallback[model_name]: continue do_export_onnx = not os.path.exists(engine_path[model_name]) and not os.path.exists(onnx_opt_path[model_name]) - if do_export_onnx: - obj.export_onnx(onnx_path[model_name], onnx_opt_path[model_name], onnx_opset, opt_image_height, opt_image_width) + do_export_weights_map = weights_map_path[model_name] and not os.path.exists(weights_map_path[model_name]) + if do_export_onnx or do_export_weights_map: + if use_fp8[model_name]: + pipeline = obj.get_pipeline() + model = pipeline.unet + + state_dict_path = self.getStateDictPath(model_name, onnx_dir, suffix=model_suffix[model_name]) + if not os.path.exists(state_dict_path): + # Load calibration images + print(f"[I] Calibrated weights not found, generating {state_dict_path}") + calibration_image_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'calibration-images') + calibration_image_list = load_calibration_images(calibration_image_folder) + print("Number of images loaded:", len(calibration_image_list)) + + # TODO check size > calibration_size + def do_calibrate(pipeline, calibration_images, **kwargs): + for i_th, image in enumerate(calibration_images): + if i_th >= kwargs["calib_size"]: + return + pipeline( + image=image, + num_inference_steps=kwargs["n_steps"], + ).frames[0] + + def forward_loop(model): + pipeline.unet = model + do_calibrate( + pipeline=pipeline, + calibration_images=calibration_image_list, + calib_size=calibration_size // calib_batch_size, + n_steps=self.denoising_steps, + ) + + print(f"[I] Performing calibration for {calibration_size} steps.") + if use_fp8[model_name]: + quant_config = SD_FP8_FP16_DEFAULT_CONFIG + check_lora(model) + mtq.quantize(model, quant_config, forward_loop) + mto.save(model, state_dict_path) + else: + mto.restore(model, state_dict_path) + + print(f"[I] Generating quantized ONNX model: {onnx_opt_path[model_name]}") + if not os.path.exists(onnx_path[model_name]): + """ + Error: Torch bug, ONNX export failed due to unknown kernel shape in QuantConv3d. + TRT_FP8QuantizeLinear and TRT_FP8DequantizeLinear operations in UNetSpatioTemporalConditionModel for svd + cause issues. Inputs on different devices (CUDA vs CPU) may contribute to the problem. + """ + quantize_lvl(model, quantization_level, enable_conv_3d=False) + mtq.disable_quantizer(model, filter_func) + if use_fp8[model_name]: + generate_fp8_scales(model) + else: + model = None + + obj.export_onnx(onnx_path[model_name], onnx_opt_path[model_name], onnx_opset, opt_image_height, opt_image_width, custom_model=model, static_shape=static_shape) + else: + obj.export_onnx(onnx_path[model_name], onnx_opt_path[model_name], onnx_opset, opt_image_height, opt_image_width) # Build TensorRT engines for model_name, obj in self.models.items(): diff --git a/demo/Diffusion/utilities.py b/demo/Diffusion/utilities.py index 92c9c089f..932771533 100755 --- a/demo/Diffusion/utilities.py +++ b/demo/Diffusion/utilities.py @@ -21,6 +21,7 @@ from collections import OrderedDict from cuda import cudart from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.utils import load_image from enum import Enum, auto import gc from io import BytesIO @@ -45,6 +46,7 @@ import tensorrt as trt import torch import types +import gc TRT_LOGGER = trt.Logger(trt.Logger.ERROR) @@ -141,12 +143,18 @@ def lora_forward(self, x, scale=None): new_linear._torch_forward = new_linear.forward new_linear.forward = types.MethodType(lora_forward, new_linear) -def merge_loras(model, lora_dict, lora_alphas, lora_scales): - assert len(lora_scales) == len(lora_dict) - for path, lora in lora_dict.items(): - print(f"[I] Fusing LoRA: {path}, scale {lora_scales[path]}") - model.load_attn_procs(lora, network_alphas=lora_alphas[path]) - model.fuse_lora(lora_scale=lora_scales[path]) +def merge_loras(model, lora_loader): + paths, weights, scale = lora_loader.paths, lora_loader.weights, lora_loader.scale + for i, path in enumerate(paths): + print(f"[I] Loading LoRA: {path}, weight {weights[i]}") + state_dict, network_alphas = lora_loader.lora_state_dict(path, unet_config=model.config) + lora_loader.load_lora_into_unet(state_dict, network_alphas=network_alphas, + unet=model, adapter_name=path) + + model.set_adapters(paths, weights=weights) + # NOTE: fuse_lora an experimental API in Diffusers + model.fuse_lora(adapter_names=paths, lora_scale=scale) + model.unload_lora() return model def CUASSERT(cuda_ret): @@ -219,21 +227,15 @@ def __del__(self): del self.buffers del self.tensors - def refit(self, refit_weights, is_fp16): + def refit(self, refit_weights, updated_weight_names): # Initialize refitter refitter = trt.Refitter(self.engine, TRT_LOGGER) - refitted_weights = set() - # iterate through all tensorrt refittable weights - for trt_weight_name in refitter.get_all_weights(): - if trt_weight_name not in refit_weights: - continue + def refit_single_weight(trt_weight_name): # get weight from state dict - trt_datatype = trt.DataType.FLOAT - if is_fp16: - refit_weights[trt_weight_name] = refit_weights[trt_weight_name].half() - trt_datatype = trt.DataType.HALF + trt_datatype = refitter.get_weights_prototype(trt_weight_name).dtype + refit_weights[trt_weight_name] = refit_weights[trt_weight_name].to(trt_to_torch_dtype_dict[trt_datatype]) # trt.Weight and trt.TensorLocation trt_wt_tensor = trt.Weights(trt_datatype, refit_weights[trt_weight_name].data_ptr(), torch.numel(refit_weights[trt_weight_name])) @@ -243,7 +245,17 @@ def refit(self, refit_weights, is_fp16): refitter.set_named_weights(trt_weight_name, trt_wt_tensor, trt_wt_location) refitted_weights.add(trt_weight_name) - assert set(refitted_weights) == set(refit_weights.keys()) + # iterate through all tensorrt refittable weights + for trt_weight_name in refitter.get_all_weights(): + if trt_weight_name not in updated_weight_names: + continue + + refit_single_weight(trt_weight_name) + + # iterate through missing weights required by tensorrt - addresses the case where lora_scale=0 + for trt_weight_name in refitter.get_missing_weights(): + refit_single_weight(trt_weight_name) + if not refitter.refit_cuda_engine(): print("Error: failed to refit new weights.") exit(0) @@ -306,8 +318,24 @@ def build(self, save_engine(engine, path=self.engine_path) def load(self): - print(f"Loading TensorRT engine: {self.engine_path}") - self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) + if self.engine is not None: + print(f"[W]: Engine {self.engine_path} already loaded, skip reloading") + return + if not hasattr(self,'engine_bytes_cpu') or self.engine_bytes_cpu is None: + # keep a cpu copy of the engine to reduce reloading time. + print(f"Loading TensorRT engine to cpu bytes: {self.engine_path}") + self.engine_bytes_cpu = bytes_from_path(self.engine_path) + print(f"Loading TensorRT engine from bytes: {self.engine_path}") + self.engine = engine_from_bytes(self.engine_bytes_cpu) + + def unload(self): + if self.engine is not None: + print(f"Unloading TensorRT engine: {self.engine_path}") + del self.engine + self.engine = None + gc.collect() + else: + print(f"[W]: Unload an unloaded engine {self.engine_path}, skip unloading") def activate(self, device_memory=None): if device_memory: @@ -559,6 +587,7 @@ def get_refit_weights(state_dict, onnx_opt_path, weight_name_mapping, weight_sha initializer_hash_mapping[initializer.name] = initializer_hash refit_weights = OrderedDict() + updated_weight_names = set() # save names of updated weights to refit only the required weights for wt_name, wt in state_dict.items(): # query initializer to compare initializer_name = weight_name_mapping[wt_name] @@ -574,14 +603,28 @@ def get_refit_weights(state_dict, onnx_opt_path, weight_name_mapping, weight_sha # include weight if hashes differ wt_hash = hash(wt.cpu().detach().numpy().astype(np.float16).data.tobytes()) if initializer_hash != wt_hash: - refit_weights[initializer_name] = wt.contiguous() - return refit_weights + updated_weight_names.add(initializer_name) + # Store all weights as the refitter may require unchanged weights too + # docs: https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#refitting-engine-c + refit_weights[initializer_name] = wt.contiguous() + return refit_weights, updated_weight_names def load_calib_prompts(batch_size, calib_data_path): with open(calib_data_path, "r") as file: lst = [line.rstrip("\n") for line in file] return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)] +def load_calibration_images(folder_path): + images = [] + for filename in os.listdir(folder_path): + img_path = os.path.join(folder_path, filename) + if os.path.isfile(img_path): + image = load_image(img_path) + if image is not None: + images.append(image) + return images + + class PercentileAmaxes: def __init__(self, total_step, percentile) -> None: self.data = {} @@ -609,7 +652,8 @@ def add_arguments(parser): parser.add_argument('--denoising-steps', type=int, default=30, help="Number of denoising steps") parser.add_argument('--scheduler', type=str, default=None, choices=("DDIM", "DDPM", "EulerA", "Euler", "LCM", "LMSD", "PNDM", "UniPC", "DDPMWuerstchen", "FlowMatchEuler"), help="Scheduler for diffusion process") parser.add_argument('--guidance-scale', type=float, default=7.5, help="Value of classifier-free guidance scale (must be greater than 1)") - parser.add_argument('--lora-scale', type=float, nargs='+', default=None, help="Scale of LoRA weights, default 1 (must between 0 and 1)") + parser.add_argument('--lora-scale', type=float, default=1.0, help="Controls how much to influence the outputs with the LoRA parameters. (must between 0 and 1)") + parser.add_argument('--lora-weight', type=float, nargs='+', default=None, help="The LoRA adapter(s) weights to use with the UNet. (must between 0 and 1)") parser.add_argument('--lora-path', type=str, nargs='+', default=None, help="Path to LoRA adaptor. Ex: 'latent-consistency/lcm-lora-sdv1-5'") # ONNX export @@ -666,8 +710,8 @@ def process_pipeline_args(args): if args.int8 and not any(args.version.startswith(prefix) for prefix in ['xl', '1.4', '1.5', '2.1']): raise ValueError(f"int8 quantization is only supported for SDXL, SD1.4, SD1.5 and SD2.1 pipelines.") - if args.fp8 and not any(args.version.startswith(prefix) for prefix in ['xl', '1.4', '1.5', '2.1']): - raise ValueError(f"fp8 quantization is only supported for SDXL, SD1.4, SD1.5 and SD2.1 pipelines.") + if args.fp8 and not any(args.version.startswith(prefix) for prefix in ['xl', '1.4', '1.5', '2.1', 'flux.1-dev']): + raise ValueError(f"fp8 quantization is only supported for SDXL, SD1.4, SD1.5, SD2.1 and FLUX pipelines.") if args.fp8 and args.int8: raise ValueError(f"Cannot apply both int8 and fp8 quantization, please choose only one.") @@ -675,8 +719,8 @@ def process_pipeline_args(args): if args.fp8: device_info = torch.cuda.get_device_properties(0) version = device_info.major * 10 + device_info.minor - if version < 90: # if Ada or older - raise ValueError(f"Cannot apply FP8 quantization for GPU with compute capability {version / 10.0}. Only Hopper is supported.") + if version < 89: + raise ValueError(f"Cannot apply FP8 quantization for GPU with compute capability {version / 10.0}. Only Ada and Hopper are supported.") if args.quantization_level == 0.0: def override_quant_level(level : float, dtype_str : str): @@ -684,16 +728,19 @@ def override_quant_level(level : float, dtype_str : str): print(f"The default quantization level has been set to {level} for {dtype_str}.") if args.fp8: - override_quant_level(3.0 if args.version in ("1.4", "1.5") else 4.0, "FP8") + override_quant_level(3.0 if args.version in ("1.4", "1.5", "flux.1-dev") else 4.0, "FP8") elif args.int8: override_quant_level(3.0, "INT8") if args.lora_path and not any(args.version.startswith(prefix) for prefix in ('1.5', '2.1', 'xl')): raise ValueError(f"LoRA adapter support is only supported for SD1.5, SD2.1 and SDXL pipelines") - if args.lora_scale: - for lora_scale in (lora_scale for lora_scale in args.lora_scale if not 0 <= lora_scale <= 1): - raise ValueError(f"Scale of LoRA weights must be between 0 and 1, provided {lora_scale}") + if args.lora_weight: + for weight in (weight for weight in args.lora_weight if not 0 <= weight <= 1): + raise ValueError(f"LoRA adapter weights must be between 0 and 1, provided {weight}") + + if not 0 <= args.lora_scale <= 1: + raise ValueError(f"LoRA scale value must be between 0 and 1, provided {args.lora_scale}") kwargs_init_pipeline = { 'version': args.version, @@ -707,6 +754,7 @@ def override_quant_level(level : float, dtype_str : str): 'nvtx_profile': args.nvtx_profile, 'use_cuda_graph': args.use_cuda_graph, 'lora_scale': args.lora_scale, + 'lora_weight': args.lora_weight, 'lora_path': args.lora_path, 'framework_model_dir': args.framework_model_dir, 'torch_inference': args.torch_inference, diff --git a/demo/Diffusion/utils_modelopt.py b/demo/Diffusion/utils_modelopt.py index 9e8755c58..e8b9d7896 100644 --- a/demo/Diffusion/utils_modelopt.py +++ b/demo/Diffusion/utils_modelopt.py @@ -107,7 +107,13 @@ def filter_func(name): ) return pattern.match(name) is not None -def quantize_lvl(unet, quant_level=2.5, linear_only=False): +def filter_func_no_proj_out(name): + pattern = re.compile( + r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out).*" + ) + return pattern.match(name) is not None + +def quantize_lvl(unet, quant_level=2.5, linear_only=False, enable_conv_3d=True): """ We should disable the unwanted quantizer when exporting the onnx Because in the current modelopt setting, it will load the quantizer amax for all the layers even @@ -132,6 +138,14 @@ def quantize_lvl(unet, quant_level=2.5, linear_only=False): else: module.input_quantizer.disable() module.weight_quantizer.disable() + elif isinstance(module, torch.nn.Conv3d) and not enable_conv_3d: + """ + Error: Torch bug, ONNX export failed due to unknown kernel shape in QuantConv3d. + TRT_FP8QuantizeLinear and TRT_FP8DequantizeLinear operations in UNetSpatioTemporalConditionModel for svd + cause issues. Inputs on different devices (CUDA vs CPU) may contribute to the problem. + """ + module.input_quantizer.disable() + module.weight_quantizer.disable() elif isinstance(module, Attention): # TRT only supports FP8 MHA with head_size % 16 == 0. head_size = int(module.inner_dim / module.heads) @@ -215,6 +229,25 @@ def get_int8_config( "algorithm": "max", } +SD_FP8_BF16_DEFAULT_CONFIG = { + "quant_cfg": { + "*weight_quantizer": {"num_bits": (4, 3), "axis": None, "trt_high_precision_dtype": "BFloat16"}, + "*input_quantizer": {"num_bits": (4, 3), "axis": None, "trt_high_precision_dtype": "BFloat16"}, + "*output_quantizer": {"enable": False}, + "*q_bmm_quantizer": {"num_bits": (4, 3), "axis": None, "trt_high_precision_dtype": "BFloat16"}, + "*k_bmm_quantizer": {"num_bits": (4, 3), "axis": None, "trt_high_precision_dtype": "BFloat16"}, + "*v_bmm_quantizer": {"num_bits": (4, 3), "axis": None, "trt_high_precision_dtype": "BFloat16"}, + "*softmax_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "BFloat16", + }, + "default": {"enable": False}, + }, + "algorithm": "max", +} + + SD_FP8_FP32_DEFAULT_CONFIG = { "quant_cfg": { "*weight_quantizer": {"num_bits": (4, 3), "axis": None, "trt_high_precision_dtype": "Float"}, @@ -465,6 +498,11 @@ def cast_fp8_mha_io(graph): insert_cast(graph, input_tensor=bmm2_node.inputs[1], attrs={"to": np.float32}) insert_cast(graph, input_tensor=bmm2_node.outputs[0], attrs={"to": np.float16}) +def set_quant_precision(quant_config, precision: str = "Half"): + for key in quant_config["quant_cfg"]: + if "trt_high_precision_dtype" in quant_config["quant_cfg"][key]: + quant_config["quant_cfg"][key]["trt_high_precision_dtype"] = precision + def convert_fp16_io(graph): """ Convert graph I/O to FP16. diff --git a/docker/rockylinux8.Dockerfile b/docker/rockylinux8.Dockerfile index ffe260343..70c5a0a65 100644 --- a/docker/rockylinux8.Dockerfile +++ b/docker/rockylinux8.Dockerfile @@ -25,7 +25,7 @@ ENV NV_CUDNN_VERSION 8.9.6.50-1 ENV NV_CUDNN_PACKAGE libcudnn8-${NV_CUDNN_VERSION}.cuda12.2 ENV NV_CUDNN_PACKAGE_DEV libcudnn8-devel-${NV_CUDNN_VERSION}.cuda12.2 -ENV TRT_VERSION 10.5.0.18 +ENV TRT_VERSION 10.6.0.26 SHELL ["/bin/bash", "-c"] RUN dnf install -y \ @@ -62,15 +62,15 @@ RUN dnf install -y python38 python38-devel &&\ # Install TensorRT RUN if [ "${CUDA_VERSION:0:2}" = "11" ]; then \ - wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-11.8.tar.gz \ - && tar -xf TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-11.8.tar.gz \ - && cp -a TensorRT-10.5.0.18/lib/*.so* /usr/lib64 \ - && pip install TensorRT-10.5.0.18/python/tensorrt-10.5.0-cp38-none-linux_x86_64.whl ;\ + wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.6.0/tars/TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \ + && tar -xf TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \ + && cp -a TensorRT-10.6.0.26/lib/*.so* /usr/lib64 \ + && pip install TensorRT-10.6.0.26/python/tensorrt-10.6.0-cp38-none-linux_x86_64.whl ;\ elif [ "${CUDA_VERSION:0:2}" = "12" ]; then \ - wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz \ - && tar -xf TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz \ - && cp -a TensorRT-10.5.0.18/lib/*.so* /usr/lib64 \ - && pip install TensorRT-10.5.0.18/python/tensorrt-10.5.0-cp38-none-linux_x86_64.whl ;\ + wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.6.0/tars/TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \ + && tar -xf TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \ + && cp -a TensorRT-10.6.0.26/lib/*.so* /usr/lib64 \ + && pip install TensorRT-10.6.0.26/python/tensorrt-10.6.0-cp38-none-linux_x86_64.whl ;\ else \ echo "Invalid CUDA_VERSION"; \ exit 1; \ diff --git a/docker/rockylinux9.Dockerfile b/docker/rockylinux9.Dockerfile index bed779a79..70994b921 100644 --- a/docker/rockylinux9.Dockerfile +++ b/docker/rockylinux9.Dockerfile @@ -25,7 +25,7 @@ ENV NV_CUDNN_VERSION 8.9.6.50-1 ENV NV_CUDNN_PACKAGE libcudnn8-${NV_CUDNN_VERSION}.cuda12.2 ENV NV_CUDNN_PACKAGE_DEV libcudnn8-devel-${NV_CUDNN_VERSION}.cuda12.2 -ENV TRT_VERSION 10.5.0.18 +ENV TRT_VERSION 10.6.0.26 SHELL ["/bin/bash", "-c"] RUN dnf install -y \ @@ -67,15 +67,15 @@ RUN dnf -y install \ # Install TensorRT RUN if [ "${CUDA_VERSION:0:2}" = "11" ]; then \ - wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-11.8.tar.gz \ - && tar -xf TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-11.8.tar.gz \ - && cp -a TensorRT-10.5.0.18/lib/*.so* /usr/lib64 \ - && pip install TensorRT-10.5.0.18/python/tensorrt-10.5.0-cp39-none-linux_x86_64.whl ;\ + wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.6.0/tars/TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \ + && tar -xf TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \ + && cp -a TensorRT-10.6.0.26/lib/*.so* /usr/lib64 \ + && pip install TensorRT-10.6.0.26/python/tensorrt-10.6.0-cp39-none-linux_x86_64.whl ;\ elif [ "${CUDA_VERSION:0:2}" = "12" ]; then \ - wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz \ - && tar -xf TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz \ - && cp -a TensorRT-10.5.0.18/lib/*.so* /usr/lib64 \ - && pip install TensorRT-10.5.0.18/python/tensorrt-10.5.0-cp39-none-linux_x86_64.whl ;\ + wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.6.0/tars/TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \ + && tar -xf TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \ + && cp -a TensorRT-10.6.0.26/lib/*.so* /usr/lib64 \ + && pip install TensorRT-10.6.0.26/python/tensorrt-10.6.0-cp39-none-linux_x86_64.whl ;\ else \ echo "Invalid CUDA_VERSION"; \ exit 1; \ diff --git a/docker/ubuntu-20.04.Dockerfile b/docker/ubuntu-20.04.Dockerfile index a2ebb6057..939eb89d3 100644 --- a/docker/ubuntu-20.04.Dockerfile +++ b/docker/ubuntu-20.04.Dockerfile @@ -28,7 +28,7 @@ ENV CUDA_VERSION_MAJOR_MINOR=12.2 ENV NV_CUDNN_PACKAGE "libcudnn8=$NV_CUDNN_VERSION-1+cuda${CUDA_VERSION_MAJOR_MINOR}" ENV NV_CUDNN_PACKAGE_DEV "libcudnn8-dev=$NV_CUDNN_VERSION-1+cuda${CUDA_VERSION_MAJOR_MINOR}" -ENV TRT_VERSION 10.5.0.18 +ENV TRT_VERSION 10.6.0.26 SHELL ["/bin/bash", "-c"] RUN apt-get update && apt-get install -y --no-install-recommends \ @@ -84,15 +84,15 @@ RUN apt-get install -y --no-install-recommends \ # Install TensorRT RUN if [ "${CUDA_VERSION:0:2}" = "11" ]; then \ - wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-11.8.tar.gz \ - && tar -xf TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-11.8.tar.gz \ - && cp -a TensorRT-10.5.0.18/lib/*.so* /usr/lib/x86_64-linux-gnu \ - && pip install TensorRT-10.5.0.18/python/tensorrt-10.5.0-cp38-none-linux_x86_64.whl ;\ + wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.6.0/tars/TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \ + && tar -xf TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \ + && cp -a TensorRT-10.6.0.26/lib/*.so* /usr/lib/x86_64-linux-gnu \ + && pip install TensorRT-10.6.0.26/python/tensorrt-10.6.0-cp38-none-linux_x86_64.whl ;\ elif [ "${CUDA_VERSION:0:2}" = "12" ]; then \ - wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz \ - && tar -xf TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz \ - && cp -a TensorRT-10.5.0.18/lib/*.so* /usr/lib/x86_64-linux-gnu \ - && pip install TensorRT-10.5.0.18/python/tensorrt-10.5.0-cp38-none-linux_x86_64.whl ;\ + wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.6.0/tars/TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \ + && tar -xf TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \ + && cp -a TensorRT-10.6.0.26/lib/*.so* /usr/lib/x86_64-linux-gnu \ + && pip install TensorRT-10.6.0.26/python/tensorrt-10.6.0-cp38-none-linux_x86_64.whl ;\ else \ echo "Invalid CUDA_VERSION"; \ exit 1; \ diff --git a/docker/ubuntu-22.04-aarch64.Dockerfile b/docker/ubuntu-22.04-aarch64.Dockerfile index e28f058c5..bd28c2bf7 100644 --- a/docker/ubuntu-22.04-aarch64.Dockerfile +++ b/docker/ubuntu-22.04-aarch64.Dockerfile @@ -20,7 +20,7 @@ ARG CUDA_VERSION=12.6.0 # Multi-arch container support available in non-cudnn containers. FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 -ENV TRT_VERSION 10.5.0.18 +ENV TRT_VERSION 10.6.0.26 SHELL ["/bin/bash", "-c"] # Setup user account diff --git a/docker/ubuntu-22.04.Dockerfile b/docker/ubuntu-22.04.Dockerfile index 0bb09b2d5..e72671bad 100644 --- a/docker/ubuntu-22.04.Dockerfile +++ b/docker/ubuntu-22.04.Dockerfile @@ -28,7 +28,7 @@ ENV CUDA_VERSION_MAJOR_MINOR=12.2 ENV NV_CUDNN_PACKAGE "libcudnn8=$NV_CUDNN_VERSION-1+cuda${CUDA_VERSION_MAJOR_MINOR}" ENV NV_CUDNN_PACKAGE_DEV "libcudnn8-dev=$NV_CUDNN_VERSION-1+cuda${CUDA_VERSION_MAJOR_MINOR}" -ENV TRT_VERSION 10.5.0.18 +ENV TRT_VERSION 10.6.0.26 SHELL ["/bin/bash", "-c"] RUN apt-get update && apt-get install -y --no-install-recommends \ @@ -84,15 +84,15 @@ RUN apt-get install -y --no-install-recommends \ # Install TensorRT RUN if [ "${CUDA_VERSION:0:2}" = "11" ]; then \ - wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-11.8.tar.gz \ - && tar -xf TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-11.8.tar.gz \ - && cp -a TensorRT-10.5.0.18/lib/*.so* /usr/lib/x86_64-linux-gnu \ - && pip install TensorRT-10.5.0.18/python/tensorrt-10.5.0-cp310-none-linux_x86_64.whl ;\ + wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.6.0/tars/TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \ + && tar -xf TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-11.8.tar.gz \ + && cp -a TensorRT-10.6.0.26/lib/*.so* /usr/lib/x86_64-linux-gnu \ + && pip install TensorRT-10.6.0.26/python/tensorrt-10.6.0-cp310-none-linux_x86_64.whl ;\ elif [ "${CUDA_VERSION:0:2}" = "12" ]; then \ - wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz \ - && tar -xf TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz \ - && cp -a TensorRT-10.5.0.18/lib/*.so* /usr/lib/x86_64-linux-gnu \ - && pip install TensorRT-10.5.0.18/python/tensorrt-10.5.0-cp310-none-linux_x86_64.whl ;\ + wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.6.0/tars/TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \ + && tar -xf TensorRT-10.6.0.26.Linux.x86_64-gnu.cuda-12.6.tar.gz \ + && cp -a TensorRT-10.6.0.26/lib/*.so* /usr/lib/x86_64-linux-gnu \ + && pip install TensorRT-10.6.0.26/python/tensorrt-10.6.0-cp310-none-linux_x86_64.whl ;\ else \ echo "Invalid CUDA_VERSION"; \ exit 1; \ diff --git a/docker/ubuntu-cross-aarch64.Dockerfile b/docker/ubuntu-cross-aarch64.Dockerfile index 3243b4a7f..8e3c3845d 100644 --- a/docker/ubuntu-cross-aarch64.Dockerfile +++ b/docker/ubuntu-cross-aarch64.Dockerfile @@ -21,7 +21,7 @@ ARG OS_VERSION=22.04 FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${OS_VERSION} LABEL maintainer="NVIDIA CORPORATION" -ENV TRT_VERSION 10.5.0.18 +ENV TRT_VERSION 10.6.0.26 ENV DEBIAN_FRONTEND=noninteractive ARG uid=1000 diff --git a/include/NvInfer.h b/include/NvInfer.h index e0231d4d1..61d71ecc0 100644 --- a/include/NvInfer.h +++ b/include/NvInfer.h @@ -8434,6 +8434,19 @@ enum class BuilderFlag : int32_t //! enabled. This flag cannot be used together with kREFIT or kREFIT_IDENTICAL. kREFIT_INDIVIDUAL = 23, + //! Disable floating-point optimizations: 0*x => 0, x-x => 0, or x/x => 1. These identities are + //! not true when x is a NaN or Inf, and thus might hide propagation or generation of NaNs. This flag is typically + //! used in combination with kSPARSE_WEIGHTS. + //! There are three valid sparsity configurations. + //! 1. Disable all sparsity. Both kSPARSE_WEIGHTS and kSTRICT_NANS are unset + //! 2. Enable sparsity only where it does not affect propagation/generation of NaNs. Both kSPARSE_WEIGHTS and + //! kSTRICT_NANS are set + //! 3. Enable all sparsity. kSPARSE_WEIGHTS is set and kSTRICT_NANS is unset + kSTRICT_NANS = 24, + + //! Enable memory monitor during build time. + kMONITOR_MEMORY = 25, + }; //! @@ -8444,7 +8457,7 @@ enum class BuilderFlag : int32_t template <> constexpr inline int32_t EnumMax() noexcept { - return 24; + return 26; } //! @@ -9024,9 +9037,9 @@ class IBuilderConfig : public INoCopy } //! - //! \brief Set the cuda stream that is used to profile this network. + //! \brief Set the CUDA stream that is used to profile this network. //! - //! \param stream The cuda stream used for profiling by the builder. + //! \param stream The CUDA stream used for profiling by the builder. //! //! \see getProfileStream() //! @@ -9036,9 +9049,9 @@ class IBuilderConfig : public INoCopy } //! - //! \brief Get the cuda stream that is used to profile this network. + //! \brief Get the CUDA stream that is used to profile this network. //! - //! \return The cuda stream set by setProfileStream, nullptr if setProfileStream has not been called. + //! \return The CUDA stream set by setProfileStream, nullptr if setProfileStream has not been called. //! //! \see setProfileStream() //! @@ -9838,7 +9851,7 @@ class IBuilder : public INoCopy //! //! \return A pointer to a IHostMemory object that contains a serialized network. //! - //! \note This function will synchronize the cuda stream returned by \p config.getProfileStream() before returning. + //! \note This function will synchronize the CUDA stream returned by \p config.getProfileStream() before returning. //! //! \see INetworkDefinition, IBuilderConfig, IHostMemory //! @@ -9847,6 +9860,26 @@ class IBuilder : public INoCopy return mImpl->buildSerializedNetwork(network, config); } + //! + //! \brief Builds a network for the given INetworkDefinition and IBuilderConfig. + //! + //! \param network Network definition. + //! \param config Builder configuration. + //! + //! \return A pointer to a ICudaEngine object that contains an engine. + //! + //! \note This function will synchronize the CUDA stream returned by \p config.getProfileStream() before returning. + //! + //! \note This function does not support \p BuilderFlag::kVERSION_COMPATIBLE. + //! Please use \p buildSerializedNetwork to get a version compatible engine. + //! + //! \see INetworkDefinition, IBuilderConfig, ICudaEngine + //! + nvinfer1::ICudaEngine* buildEngineWithConfig(INetworkDefinition& network, IBuilderConfig& config) noexcept + { + return mImpl->buildEngineWithConfig(network, config); + } + //! //! \brief Checks that a network is within the scope of the IBuilderConfig settings. //! @@ -9862,7 +9895,7 @@ class IBuilder : public INoCopy //! \return True if network is within the scope of the restrictions specified by the builder config, //! false otherwise. //! - //! \note This function will synchronize the cuda stream returned by \p config.getProfileStream() before returning. + //! \note This function will synchronize the CUDA stream returned by \p config.getProfileStream() before returning. //! bool isNetworkSupported(INetworkDefinition const& network, IBuilderConfig const& config) const noexcept { diff --git a/include/NvInferImpl.h b/include/NvInferImpl.h index 2c7df74af..3bb39fa40 100644 --- a/include/NvInferImpl.h +++ b/include/NvInferImpl.h @@ -26,6 +26,8 @@ namespace nvinfer1 { +class ILogger; + namespace v_1_0 { class IProgressMonitor; @@ -113,6 +115,12 @@ class IPluginV3; } // namespace v_1_0 using IPluginV3 = v_1_0::IPluginV3; +namespace v_1_0 +{ +class IStreamReader; +} // namespace v_1_0 +using IStreamReader = v_1_0::IStreamReader; + class IPluginV3Layer; class IPoolingLayer; class IQuantizeLayer; @@ -1199,13 +1207,14 @@ class VBuilder : public VRoot virtual IErrorRecorder* getErrorRecorder() const noexcept = 0; virtual void reset() noexcept = 0; virtual bool platformHasTf32() const noexcept = 0; - virtual nvinfer1::IHostMemory* buildSerializedNetwork(INetworkDefinition& network, IBuilderConfig& config) noexcept - = 0; + virtual nvinfer1::IHostMemory* buildSerializedNetwork( + INetworkDefinition& network, IBuilderConfig& config) noexcept = 0; virtual bool isNetworkSupported(INetworkDefinition const& network, IBuilderConfig const& config) const noexcept = 0; virtual ILogger* getLogger() const noexcept = 0; virtual bool setMaxThreads(int32_t maxThreads) noexcept = 0; virtual int32_t getMaxThreads() const noexcept = 0; virtual IPluginRegistry& getPluginRegistry() noexcept = 0; + virtual ICudaEngine* buildEngineWithConfig(INetworkDefinition& network, IBuilderConfig& config) noexcept = 0; }; } // namespace apiv diff --git a/include/NvInferLegacyDims.h b/include/NvInferLegacyDims.h index 2725d184f..ecedf55a5 100644 --- a/include/NvInferLegacyDims.h +++ b/include/NvInferLegacyDims.h @@ -18,9 +18,9 @@ #ifndef NV_INFER_LEGACY_DIMS_H #define NV_INFER_LEGACY_DIMS_H -#define NV_INFER_INTERNAL_INCLUDE_RUNTIME_BASE 1 +#define NV_INFER_INTERNAL_INCLUDE 1 #include "NvInferRuntimeBase.h" -#undef NV_INFER_INTERNAL_INCLUDE_RUNTIME_BASE +#undef NV_INFER_INTERNAL_INCLUDE //! //! \file NvInferLegacyDims.h diff --git a/include/NvInferPluginBase.h b/include/NvInferPluginBase.h new file mode 100644 index 000000000..d337f48ab --- /dev/null +++ b/include/NvInferPluginBase.h @@ -0,0 +1,372 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NV_INFER_PLUGIN_BASE_H +#define NV_INFER_PLUGIN_BASE_H + +#if !defined(NV_INFER_INTERNAL_INCLUDE) +static_assert(false, "Do not directly include this file. Include NvInferRuntime.h or NvInferPluginUtils.h"); +#endif + +#define NV_INFER_INTERNAL_INCLUDE 1 +#include "NvInferRuntimeBase.h" +#undef NV_INFER_INTERNAL_INCLUDE +namespace nvinfer1 +{ + +//! +//! \enum PluginFieldType +//! +//! \brief The possible field types for custom layer. +//! +enum class PluginFieldType : int32_t +{ + //! FP16 field type. + kFLOAT16 = 0, + //! FP32 field type. + kFLOAT32 = 1, + //! FP64 field type. + kFLOAT64 = 2, + //! INT8 field type. + kINT8 = 3, + //! INT16 field type. + kINT16 = 4, + //! INT32 field type. + kINT32 = 5, + //! char field type. + kCHAR = 6, + //! nvinfer1::Dims field type. + kDIMS = 7, + //! Unknown field type. + kUNKNOWN = 8, + //! BF16 field type. + kBF16 = 9, + //! INT64 field type. + kINT64 = 10, + //! FP8 field type. + kFP8 = 11, + //! INT4 field type. + kINT4 = 12, +}; + +//! +//! \class PluginField +//! +//! \brief Structure containing plugin attribute field names and associated data +//! This information can be parsed to decode necessary plugin metadata +//! +//! +class PluginField +{ +public: + //! Plugin field attribute name + AsciiChar const* name; + //! Plugin field attribute data + void const* data; + //! Plugin field attribute type + PluginFieldType type; + //! Number of data entries in the Plugin attribute + int32_t length; + + PluginField(AsciiChar const* const name_ = nullptr, void const* const data_ = nullptr, + PluginFieldType const type_ = PluginFieldType::kUNKNOWN, int32_t const length_ = 0) noexcept + : name(name_) + , data(data_) + , type(type_) + , length(length_) + { + } +}; + +//! +//! \struct PluginFieldCollection +//! +//! \brief Plugin field collection struct. +//! +struct PluginFieldCollection +{ + //! Number of PluginField entries. + int32_t nbFields{}; + //! Pointer to PluginField entries. + PluginField const* fields{}; +}; + +//! +//! \enum TensorRTPhase +//! +//! \brief Indicates a phase of operation of TensorRT +//! +enum class TensorRTPhase : int32_t +{ + //! Build phase of TensorRT + kBUILD = 0, + //! Execution phase of TensorRT + kRUNTIME = 1 +}; + +//! +//! \enum PluginCapabilityType +//! +//! \brief Enumerates the different capability types a IPluginV3 object may have +//! +enum class PluginCapabilityType : int32_t +{ + //! Core capability. Every IPluginV3 object must have this. + kCORE = 0, + //! Build capability. IPluginV3 objects provided to TensorRT build phase must have this. + kBUILD = 1, + //! Runtime capability. IPluginV3 objects provided to TensorRT build and execution phases must have this. + kRUNTIME = 2 +}; + +namespace v_1_0 +{ +class IPluginCapability : public IVersionedInterface +{ +}; + +class IPluginResource : public IVersionedInterface +{ +public: + //! + //! \brief Return version information associated with this interface. Applications must not override this method. + //! + InterfaceInfo getInterfaceInfo() const noexcept override + { + return InterfaceInfo{"IPluginResource", 1, 0}; + } + //! + //! \brief Free the underlying resource + //! + //! This will only be called for IPluginResource objects that were produced from IPluginResource::clone() + //! + //! The IPluginResource object on which release() is called must still be in a clone-able state + //! after release() returns + //! + //! \return 0 for success, else non-zero + //! \usage + //! - Allowed context for the API call + //! - Thread-safe: No; this method is not required to be thread-safe + //! + virtual int32_t release() noexcept = 0; + + //! + //! \brief Clone the resource object + //! + //! \note Resource initialization (if any) may be skipped for non-cloned objects since only clones will be + //! registered by TensorRT + //! + //! \return Pointer to cloned object. nullptr if there was an issue. + //! + //! \usage + //! - Allowed context for the API call + //! - Thread-safe: Yes; this method is required to be thread-safe and may be called from multiple threads. + //! + virtual IPluginResource* clone() noexcept = 0; + + ~IPluginResource() noexcept override = default; + + IPluginResource() = default; + IPluginResource(IPluginResource const&) = default; + IPluginResource(IPluginResource&&) = default; + IPluginResource& operator=(IPluginResource const&) & = default; + IPluginResource& operator=(IPluginResource&&) & = default; +}; // class IPluginResource + +class IPluginCreatorInterface : public IVersionedInterface +{ +public: + ~IPluginCreatorInterface() noexcept override = default; + +protected: + IPluginCreatorInterface() = default; + IPluginCreatorInterface(IPluginCreatorInterface const&) = default; + IPluginCreatorInterface(IPluginCreatorInterface&&) = default; + IPluginCreatorInterface& operator=(IPluginCreatorInterface const&) & = default; + IPluginCreatorInterface& operator=(IPluginCreatorInterface&&) & = default; +}; + +class IPluginV3 : public IVersionedInterface +{ +public: + //! + //! \brief Return version information associated with this interface. Applications must not override this method. + //! + InterfaceInfo getInterfaceInfo() const noexcept override + { + return InterfaceInfo{"PLUGIN", 1, 0}; + } + + //! \brief Return a pointer to plugin object implementing the specified PluginCapabilityType. + //! + //! \note IPluginV3 objects added for the build phase (through addPluginV3()) must return valid objects for + //! PluginCapabilityType::kCORE, PluginCapabilityType::kBUILD and PluginCapabilityType::kRUNTIME. + //! + //! \note IPluginV3 objects added for the runtime phase must return valid objects for + //! PluginCapabilityType::kCORE and PluginCapabilityType::kRUNTIME. + //! + //! \see TensorRTPhase + //! \see IPluginCreatorV3One::createPlugin() + //! + virtual IPluginCapability* getCapabilityInterface(PluginCapabilityType type) noexcept = 0; + + //! + //! \brief Clone the plugin object. This copies over internal plugin parameters and returns a new plugin object with + //! these parameters. The cloned object must be in a fully initialized state. + //! + //! \note The cloned object must return valid objects through getCapabilityInterface() for at least the same + //! PluginCapabilityTypes as the original object. + //! + //! \return A cloned plugin object in an initialized state with the same parameters as the current object. + //! nullptr must be returned if the cloning fails. + //! + virtual IPluginV3* clone() noexcept = 0; +}; + +class IPluginCreatorV3One : public IPluginCreatorInterface +{ +public: + //! + //! \brief Return version information associated with this interface. Applications must not override this method. + //! + InterfaceInfo getInterfaceInfo() const noexcept override + { + return InterfaceInfo{"PLUGIN CREATOR_V3ONE", 1, 0}; + } + + //! + //! \brief Return a plugin object. Return nullptr in case of error. + //! + //! \param name A NULL-terminated name string of length 1024 or less, including the NULL terminator. + //! \param fc A pointer to a collection of fields needed for constructing the plugin. + //! \param phase The TensorRT phase in which the plugin is being created + //! + //! When the phase is TensorRTPhase::kRUNTIME, the PluginFieldCollection provided for serialization by the plugin's + //! runtime interface will be passed as fc. + //! + //! \note The returned plugin object must be in an initialized state + //! + //! \note If invoked by the user (e.g. with TensorRTPhase::kBUILD, to add to the network defintion with + //! addPluginV3()), it is the user's responsibility to delete the plugin object. If invoked by TensorRT (e.g. during + //! engine deserialization), TensorRT will delete any objects it creates. + //! + virtual IPluginV3* createPlugin( + AsciiChar const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept = 0; + + //! + //! \brief Return a list of fields that need to be passed to createPlugin() when creating a plugin for use in the + //! TensorRT build phase. + //! + //! \see PluginFieldCollection + //! + virtual PluginFieldCollection const* getFieldNames() noexcept = 0; + + //! + //! \brief Return the plugin name. + //! + //! \warning The string returned must be NULL-terminated and have a length of 1024 bytes or less including + //! the NULL terminator. + //! + virtual AsciiChar const* getPluginName() const noexcept = 0; + + //! + //! \brief Return the plugin version. + //! + //! \warning The string returned must be NULL-terminated and have a length of 1024 bytes or less including + //! the NULL terminator. + //! + virtual AsciiChar const* getPluginVersion() const noexcept = 0; + + //! + //! \brief Return the plugin namespace. + //! + //! \warning The string returned must be NULL-terminated and have a length of 1024 bytes or less including + //! the NULL terminator. + //! + virtual AsciiChar const* getPluginNamespace() const noexcept = 0; + + IPluginCreatorV3One() = default; + virtual ~IPluginCreatorV3One() = default; + +protected: + IPluginCreatorV3One(IPluginCreatorV3One const&) = default; + IPluginCreatorV3One(IPluginCreatorV3One&&) = default; + IPluginCreatorV3One& operator=(IPluginCreatorV3One const&) & = default; + IPluginCreatorV3One& operator=(IPluginCreatorV3One&&) & = default; +}; + +} // namespace v_1_0 + +//! +//! \class IPluginCreatorV3One +//! +//! \brief A plugin creator class capable of producing IPluginV3 objects +//! +//! \see IPluginV3 +//! \see IPluginRegistry +//! +using IPluginCreatorV3One = v_1_0::IPluginCreatorV3One; + +//! +//! \class IPluginResource +//! +//! \brief Interface for plugins to define custom resources that could be shared through the plugin registry +//! +//! \see IPluginRegistry::acquirePluginResource +//! \see IPluginRegistry::releasePluginResource +//! +using IPluginResource = v_1_0::IPluginResource; + +//! +//! \class IPluginCreatorInterface +//! +//! \brief Base class for all plugin creator versions. +//! +//! \see IPluginCreator and IPluginRegistry +//! +using IPluginCreatorInterface = v_1_0::IPluginCreatorInterface; + +//! +//! \class IPluginV3 +//! +//! \brief Plugin class for the V3 generation of user-implemented layers. +//! +//! IPluginV3 acts as a wrapper around the plugin capability interfaces that define the actual behavior of the plugin. +//! +//! \see IPluginCapability +//! \see IPluginCreatorV3One +//! \see IPluginRegistry +//! +using IPluginV3 = v_1_0::IPluginV3; + +//! +//! \class IPluginCapability +//! +//! \brief Base class for plugin capability interfaces +//! +//! IPluginCapability represents a split in TensorRT V3 plugins to sub-objects that expose different types of +//! capabilites a plugin may have, as opposed to a single interface which defines all capabilities and behaviors of a +//! plugin. +//! +//! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI. +//! +//! \see PluginCapabilityType +//! +using IPluginCapability = v_1_0::IPluginCapability; +} // namespace nvinfer1 + +#endif /* NV_INFER_PLUGIN_BASE_H */ diff --git a/include/NvInferRuntime.h b/include/NvInferRuntime.h index 485628a68..a9e607195 100644 --- a/include/NvInferRuntime.h +++ b/include/NvInferRuntime.h @@ -25,6 +25,9 @@ //! #include "NvInferImpl.h" +#define NV_INFER_INTERNAL_INCLUDE 1 +#include "NvInferPluginBase.h" +#undef NV_INFER_INTERNAL_INCLUDE #include "NvInferRuntimeCommon.h" namespace nvinfer1 @@ -622,6 +625,55 @@ class TRT_DEPRECATED IPluginV2DynamicExt : public nvinfer1::IPluginV2Ext } }; +namespace v_1_0 +{ +class IStreamReader : public IVersionedInterface +{ +public: + //! + //! TensorRT never calls the destructor for an IStreamReader defined by the + //! application. + //! + ~IStreamReader() override = default; + IStreamReader() = default; + + //! + //! \brief Return version information associated with this interface. Applications must not override this method. + //! + InterfaceInfo getInterfaceInfo() const noexcept override + { + return InterfaceInfo{"IStreamReader", 1, 0}; + } + + //! + //! \brief Read the next number of bytes in the stream. + //! + //! \param destination The memory to write to + //! \param nbBytes The number of bytes to read + //! + //! \returns The number of bytes read. Negative values will be considered an automatic error. + //! + virtual int64_t read(void* destination, int64_t nbBytes) = 0; + +protected: + IStreamReader(IStreamReader const&) = default; + IStreamReader(IStreamReader&&) = default; + IStreamReader& operator=(IStreamReader const&) & = default; + IStreamReader& operator=(IStreamReader&&) & = default; +}; +} // namespace v_1_0 + +//! +//! \class IStreamReader +//! +//! \brief Application-implemented class for reading data in a stream-based manner. +//! +//! \note To ensure compatibility of source code with future versions of TensorRT, use IStreamReader, not +//! v_1_0::IStreamReader +//! +using IStreamReader = v_1_0::IStreamReader; + + //! //! \class IPluginResourceContext //! @@ -659,82 +711,6 @@ class IPluginResourceContext IPluginResourceContext& operator=(IPluginResourceContext&&) & = default; }; -namespace v_1_0 -{ -class IPluginCapability : public IVersionedInterface -{ -}; -} // namespace v_1_0 - -//! -//! \class IPluginCapability -//! -//! \brief Base class for plugin capability interfaces -//! -//! IPluginCapability represents a split in TensorRT V3 plugins to sub-objects that expose different types of -//! capabilites a plugin may have, as opposed to a single interface which defines all capabilities and behaviors of a -//! plugin. -//! -//! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI. -//! -//! \see PluginCapabilityType -//! -using IPluginCapability = v_1_0::IPluginCapability; - -namespace v_1_0 -{ -class IPluginV3 : public IVersionedInterface -{ -public: - //! - //! \brief Return version information associated with this interface. Applications must not override this method. - //! - InterfaceInfo getInterfaceInfo() const noexcept override - { - return InterfaceInfo{"PLUGIN", 1, 0}; - } - - //! \brief Return a pointer to plugin object implementing the specified PluginCapabilityType. - //! - //! \note IPluginV3 objects added for the build phase (through addPluginV3()) must return valid objects for - //! PluginCapabilityType::kCORE, PluginCapabilityType::kBUILD and PluginCapabilityType::kRUNTIME. - //! - //! \note IPluginV3 objects added for the runtime phase must return valid objects for - //! PluginCapabilityType::kCORE and PluginCapabilityType::kRUNTIME. - //! - //! \see TensorRTPhase - //! \see IPluginCreatorV3One::createPlugin() - //! - virtual IPluginCapability* getCapabilityInterface(PluginCapabilityType type) noexcept = 0; - - //! - //! \brief Clone the plugin object. This copies over internal plugin parameters and returns a new plugin object with - //! these parameters. The cloned object must be in a fully initialized state. - //! - //! \note The cloned object must return valid objects through getCapabilityInterface() for at least the same - //! PluginCapabilityTypes as the original object. - //! - //! \return A cloned plugin object in an initialized state with the same parameters as the current object. - //! nullptr must be returned if the cloning fails. - //! - virtual IPluginV3* clone() noexcept = 0; -}; - -} // namespace v_1_0 - -//! -//! \class IPluginV3 -//! -//! \brief Plugin class for the V3 generation of user-implemented layers. -//! -//! IPluginV3 acts as a wrapper around the plugin capability interfaces that define the actual behavior of the plugin. -//! -//! \see IPluginCapability -//! \see IPluginCreatorV3One -//! \see IPluginRegistry -//! -using IPluginV3 = v_1_0::IPluginV3; - namespace v_1_0 { class IPluginV3OneCore : public IPluginCapability @@ -815,6 +791,8 @@ class IPluginV3OneBuild : public IPluginCapability //! \param out The output tensors attributes that are used for configuration. //! \param nbOutputs Number of output tensors. //! + //! \return 0 for success, else non-zero (which will cause engine termination, if invoked by TensorRT). + //! virtual int32_t configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept = 0; @@ -1185,87 +1163,6 @@ using IPluginV3OneRuntime = v_1_0::IPluginV3OneRuntime; //! using IPluginV3OneBuildV2 = v_2_0::IPluginV3OneBuild; -namespace v_1_0 -{ -class IPluginCreatorV3One : public IPluginCreatorInterface -{ -public: - //! - //! \brief Return version information associated with this interface. Applications must not override this method. - //! - InterfaceInfo getInterfaceInfo() const noexcept override - { - return InterfaceInfo{"PLUGIN CREATOR_V3ONE", 1, 0}; - } - - //! - //! \brief Return a plugin object. Return nullptr in case of error. - //! - //! \param name A NULL-terminated name string of length 1024 or less, including the NULL terminator. - //! \param fc A pointer to a collection of fields needed for constructing the plugin. - //! \param phase The TensorRT phase in which the plugin is being created - //! - //! When the phase is TensorRTPhase::kRUNTIME, the PluginFieldCollection provided for serialization by the plugin's - //! runtime interface will be passed as fc. - //! - //! \note The returned plugin object must be in an initialized state - //! - virtual IPluginV3* createPlugin( - AsciiChar const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept = 0; - - //! - //! \brief Return a list of fields that need to be passed to createPlugin() when creating a plugin for use in the - //! TensorRT build phase. - //! - //! \see PluginFieldCollection - //! - virtual PluginFieldCollection const* getFieldNames() noexcept = 0; - - //! - //! \brief Return the plugin name. - //! - //! \warning The string returned must be NULL-terminated and have a length of 1024 bytes or less including - //! the NULL terminator. - //! - virtual AsciiChar const* getPluginName() const noexcept = 0; - - //! - //! \brief Return the plugin version. - //! - //! \warning The string returned must be NULL-terminated and have a length of 1024 bytes or less including - //! the NULL terminator. - //! - virtual AsciiChar const* getPluginVersion() const noexcept = 0; - - //! - //! \brief Return the plugin namespace. - //! - //! \warning The string returned must be NULL-terminated and have a length of 1024 bytes or less including - //! the NULL terminator. - //! - virtual AsciiChar const* getPluginNamespace() const noexcept = 0; - - IPluginCreatorV3One() = default; - virtual ~IPluginCreatorV3One() = default; - -protected: - IPluginCreatorV3One(IPluginCreatorV3One const&) = default; - IPluginCreatorV3One(IPluginCreatorV3One&&) = default; - IPluginCreatorV3One& operator=(IPluginCreatorV3One const&) & = default; - IPluginCreatorV3One& operator=(IPluginCreatorV3One&&) & = default; -}; -} // namespace v_1_0 - -//! -//! \class IPluginCreatorV3One -//! -//! \brief A plugin creator class capable of producing IPluginV3 objects -//! -//! \see IPluginV3 -//! \see IPluginRegistry -//! -using IPluginCreatorV3One = v_1_0::IPluginCreatorV3One; - namespace v_1_0 { class IProfiler @@ -1375,6 +1272,464 @@ constexpr inline int32_t EnumMax() noexcept //! IRuntime::getTempfileControlFlags() using TempfileControlFlags = uint32_t; +//! +//! \enum TensorFormat +//! +//! \brief Format of the input/output tensors. +//! +//! This enum is used by both plugins and network I/O tensors. +//! +//! \see IPluginV2::supportsFormat(), safe::ICudaEngine::getBindingFormat() +//! +//! Many of the formats are **vector-major** or **vector-minor**. These formats specify +//! a vector dimension and scalars per vector. +//! For example, suppose that the tensor has has dimensions [M,N,C,H,W], +//! the vector dimension is C and there are V scalars per vector. +//! +//! * A **vector-major** format splits the vectorized dimension into two axes in the +//! memory layout. The vectorized dimension is replaced by an axis of length ceil(C/V) +//! and a new dimension of length V is appended. For the example tensor, the memory layout +//! is equivalent to an array with dimensions [M][N][ceil(C/V)][H][W][V]. +//! Tensor coordinate (m,n,c,h,w) maps to array location [m][n][c/V][h][w][c\%V]. +//! +//! * A **vector-minor** format moves the vectorized dimension to become the last axis +//! in the memory layout. For the example tensor, the memory layout is equivalent to an +//! array with dimensions [M][N][H][W][ceil(C/V)*V]. Tensor coordinate (m,n,c,h,w) maps +//! array location subscript [m][n][h][w][c]. +//! +//! In interfaces that refer to "components per element", that's the value of V above. +//! +//! For more information about data formats, see the topic "Data Format Description" located in the +//! TensorRT Developer Guide. https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#data-format-desc +//! +enum class TensorFormat : int32_t +{ + //! Memory layout is similar to an array in C or C++. + //! The stride of each dimension is the product of the dimensions after it. + //! The last dimension has unit stride. + //! + //! For DLA usage, the tensor sizes are limited to C,H,W in the range [1,8192]. + kLINEAR = 0, + + //! Vector-major format with two scalars per vector. + //! Vector dimension is third to last. + //! + //! This format requires FP16 and at least three dimensions. + kCHW2 = 1, + + //! Vector-minor format with eight scalars per vector. + //! Vector dimension is third to last. + //! This format requires FP16 or BF16 and at least three dimensions. + kHWC8 = 2, + + //! Vector-major format with four scalars per vector. + //! Vector dimension is third to last. + //! + //! This format requires INT8 or FP16 and at least three dimensions. + //! For INT8, the length of the vector dimension must be a build-time constant. + //! + //! Deprecated usage: + //! + //! If running on the DLA, this format can be used for acceleration + //! with the caveat that C must be less than or equal to 4. + //! If used as DLA input and the build option kGPU_FALLBACK is not specified, + //! it needs to meet line stride requirement of DLA format. Column stride in + //! bytes must be a multiple of 64 on Orin. + kCHW4 = 3, + + //! Vector-major format with 16 scalars per vector. + //! Vector dimension is third to last. + //! + //! This format requires FP16 and at least three dimensions. + //! + //! For DLA usage, this format maps to the native feature format for FP16, + //! and the tensor sizes are limited to C,H,W in the range [1,8192]. + kCHW16 = 4, + + //! Vector-major format with 32 scalars per vector. + //! Vector dimension is third to last. + //! + //! This format requires at least three dimensions. + //! + //! For DLA usage, this format maps to the native feature format for INT8, + //! and the tensor sizes are limited to C,H,W in the range [1,8192]. + kCHW32 = 5, + + //! Vector-minor format with eight scalars per vector. + //! Vector dimension is fourth to last. + //! + //! This format requires FP16 or BF16 and at least four dimensions. + kDHWC8 = 6, + + //! Vector-major format with 32 scalars per vector. + //! Vector dimension is fourth to last. + //! + //! This format requires FP16 or INT8 and at least four dimensions. + kCDHW32 = 7, + + //! Vector-minor format where channel dimension is third to last and unpadded. + //! + //! This format requires either FP32, FP16, UINT8, INT64 or BF16 and at least three dimensions. + kHWC = 8, + + //! DLA planar format. For a tensor with dimension {N, C, H, W}, the W axis + //! always has unit stride. The stride for stepping along the H axis is + //! rounded up to 64 bytes. + //! + //! The memory layout is equivalent to a C array with dimensions + //! [N][C][H][roundUp(W, 64/elementSize)] where elementSize is + //! 2 for FP16 and 1 for Int8, with the tensor coordinates (n, c, h, w) + //! mapping to array subscript [n][c][h][w]. + kDLA_LINEAR = 9, + + //! DLA image format. For a tensor with dimension {N, C, H, W} the C axis + //! always has unit stride. The stride for stepping along the H axis is rounded up + //! to 64 bytes on Orin. C can only be 1, 3 or 4. + //! If C == 1, it will map to grayscale format. + //! If C == 3 or C == 4, it will map to color image format. And if C == 3, + //! the stride for stepping along the W axis needs to be padded to 4 in elements. + //! + //! When C is {1, 3, 4}, then C' is {1, 4, 4} respectively, + //! the memory layout is equivalent to a C array with dimensions + //! [N][H][roundUp(W, 64/C'/elementSize)][C'] on Orin + //! where elementSize is 2 for FP16 + //! and 1 for Int8. The tensor coordinates (n, c, h, w) mapping to array + //! subscript [n][h][w][c]. + kDLA_HWC4 = 10, + + //! Vector-minor format with 16 scalars per vector. + //! Vector dimension is third to last. + //! + //! This requires FP16 or INT8 and at least three dimensions. + kHWC16 = 11, + + //! Vector-minor format with one scalar per vector. + //! Vector dimension is fourth to last. + //! + //! This format requires FP32 and at least four dimensions. + kDHWC = 12 +}; + +namespace impl +{ +//! Maximum number of elements in TensorFormat enum. \see TensorFormat +template <> +struct EnumMaxImpl +{ + //! Declaration of kVALUE that represents the maximum number of elements in the TensorFormat enum. + static constexpr int32_t kVALUE = 13; +}; +} // namespace impl + +//! +//! \enum AllocatorFlag +//! +//! \brief Allowed type of memory allocation. +//! +enum class AllocatorFlag : int32_t +{ + //! TensorRT may call realloc() on this allocation. + kRESIZABLE = 0, +}; + +namespace impl +{ +//! Maximum number of elements in AllocatorFlag enum. \see AllocatorFlag +template <> +struct EnumMaxImpl +{ + //! Declaration of kVALUE that represents the maximum number of elements in the AllocatorFlag enum. + static constexpr int32_t kVALUE = 1; +}; +} // namespace impl + +using AllocatorFlags = uint32_t; + +//! DO NOT REFER TO namespace v_1_0 IN CODE. ALWAYS USE nvinfer1 INSTEAD. +//! The name v_1_0 may change in future versions of TensoRT. + +//! +//! \class ILogger +//! +//! \brief Application-implemented logging interface for the builder, refitter and runtime. +//! +//! The logger used to create an instance of IBuilder, IRuntime or IRefitter is used for all objects created through +//! that interface. The logger must be valid until all objects created are released. +//! +//! The Logger object implementation must be thread safe. All locking and synchronization is pushed to the +//! interface implementation and TensorRT does not hold any synchronization primitives when calling the interface +//! functions. +//! +class ILogger +{ +public: + //! + //! \enum Severity + //! + //! \brief The severity corresponding to a log message. + //! + enum class Severity : int32_t + { + //! An internal error has occurred. Execution is unrecoverable. + kINTERNAL_ERROR = 0, + //! An application error has occurred. + kERROR = 1, + //! An application error has been discovered, but TensorRT has recovered or fallen back to a default. + kWARNING = 2, + //! Informational messages with instructional information. + kINFO = 3, + //! Verbose messages with debugging information. + kVERBOSE = 4, + }; + + //! + //! \brief A callback implemented by the application to handle logging messages; + //! + //! \param severity The severity of the message. + //! \param msg A null-terminated log message. + //! + //! \warning Loggers used in the safety certified runtime must set a maximum message length and truncate + //! messages exceeding this length. It is up to the implementer of the derived class to define + //! a suitable limit that will prevent buffer overruns, resource exhaustion, and other security + //! vulnerabilities in their implementation. The TensorRT safety certified runtime will never + //! emit messages longer than 1024 bytes. + //! + //! \usage + //! - Allowed context for the API call + //! - Thread-safe: Yes, this method is required to be thread-safe and may be called from multiple threads + //! when multiple execution contexts are used during runtime, or if the same logger is used + //! for multiple runtimes, builders, or refitters. + //! + virtual void log(Severity severity, AsciiChar const* msg) noexcept = 0; + + ILogger() = default; + virtual ~ILogger() = default; + +protected: + // @cond SuppressDoxyWarnings + ILogger(ILogger const&) = default; + ILogger(ILogger&&) = default; + ILogger& operator=(ILogger const&) & = default; + ILogger& operator=(ILogger&&) & = default; + // @endcond +}; + +namespace impl +{ +//! Maximum number of elements in ILogger::Severity enum. \see ILogger::Severity +template <> +struct EnumMaxImpl +{ + //! Declaration of kVALUE that represents the maximum number of elements in the ILogger::Severity enum. + static constexpr int32_t kVALUE = 5; +}; +} // namespace impl + +namespace v_1_0 +{ + +class IGpuAllocator : public IVersionedInterface +{ +public: + //! + //! \brief A thread-safe callback implemented by the application to handle acquisition of GPU memory. + //! + //! \param size The size of the memory block required (in bytes). + //! \param alignment The required alignment of memory. Alignment will be zero + //! or a power of 2 not exceeding the alignment guaranteed by cudaMalloc. + //! Thus this allocator can be safely implemented with cudaMalloc/cudaFree. + //! An alignment value of zero indicates any alignment is acceptable. + //! \param flags Reserved for future use. In the current release, 0 will be passed. + //! + //! \return If the allocation was successful, the start address of a device memory block of the requested size. + //! If an allocation request of size 0 is made, nullptr must be returned. + //! If an allocation request cannot be satisfied, nullptr must be returned. + //! If a non-null address is returned, it is guaranteed to have the specified alignment. + //! + //! \note The implementation must guarantee thread safety for concurrent allocate/reallocate/deallocate + //! requests. + //! + //! \usage + //! - Allowed context for the API call + //! - Thread-safe: Yes, this method is required to be thread-safe and may be called from multiple threads. + //! + //! \deprecated Deprecated in TensorRT 10.0. Superseded by allocateAsync + //! + TRT_DEPRECATED virtual void* allocate( + uint64_t const size, uint64_t const alignment, AllocatorFlags const flags) noexcept = 0; + + ~IGpuAllocator() override = default; + IGpuAllocator() = default; + + //! + //! \brief A thread-safe callback implemented by the application to resize an existing allocation. + //! + //! Only allocations which were allocated with AllocatorFlag::kRESIZABLE will be resized. + //! + //! Options are one of: + //! * resize in place leaving min(oldSize, newSize) bytes unchanged and return the original address + //! * move min(oldSize, newSize) bytes to a new location of sufficient size and return its address + //! * return nullptr, to indicate that the request could not be fulfilled. + //! + //! If nullptr is returned, TensorRT will assume that resize() is not implemented, and that the + //! allocation at baseAddr is still valid. + //! + //! This method is made available for use cases where delegating the resize + //! strategy to the application provides an opportunity to improve memory management. + //! One possible implementation is to allocate a large virtual device buffer and + //! progressively commit physical memory with cuMemMap. CU_MEM_ALLOC_GRANULARITY_RECOMMENDED + //! is suggested in this case. + //! + //! TensorRT may call realloc to increase the buffer by relatively small amounts. + //! + //! \param baseAddr the address of the original allocation, which will have been returned by previously calling + //! allocate() or reallocate() on the same object. + //! \param alignment The alignment used by the original allocation. This will be the same value that was previously + //! passed to the allocate() or reallocate() call that returned baseAddr. + //! \param newSize The new memory size required (in bytes). + //! + //! \return The address of the reallocated memory, or nullptr. If a non-null address is returned, it is + //! guaranteed to have the specified alignment. + //! + //! \note The implementation must guarantee thread safety for concurrent allocate/reallocate/deallocate + //! requests. + //! + //! \usage + //! - Allowed context for the API call + //! - Thread-safe: Yes, this method is required to be thread-safe and may be called from multiple threads. + //! + virtual void* reallocate(void* const /*baseAddr*/, uint64_t /*alignment*/, uint64_t /*newSize*/) noexcept + { + return nullptr; + } + + //! + //! \brief A thread-safe callback implemented by the application to handle release of GPU memory. + //! + //! TensorRT may pass a nullptr to this function if it was previously returned by allocate(). + //! + //! \param memory A memory address that was previously returned by an allocate() or reallocate() call of the same + //! allocator object. + //! + //! \return True if the acquired memory is released successfully. + //! + //! \note The implementation must guarantee thread safety for concurrent allocate/reallocate/deallocate + //! requests. + //! + //! \usage + //! - Allowed context for the API call + //! - Thread-safe: Yes, this method is required to be thread-safe and may be called from multiple threads. + //! \deprecated Deprecated in TensorRT 10.0. Superseded by deallocateAsync + //! + TRT_DEPRECATED virtual bool deallocate(void* const memory) noexcept = 0; + + //! + //! \brief A thread-safe callback implemented by the application to handle stream-ordered acquisition of GPU memory. + //! + //! The default behavior is to call method allocate(), which is synchronous and thus loses + //! any performance benefits of asynchronous allocation. If you want the benefits of asynchronous + //! allocation, see discussion of IGpuAsyncAllocator vs. IGpuAllocator in the documentation + //! for nvinfer1::IGpuAllocator. + //! + //! \param size The size of the memory block required (in bytes). + //! \param alignment The required alignment of memory. Alignment will be zero + //! or a power of 2 not exceeding the alignment guaranteed by cudaMalloc. + //! Thus this allocator can be safely implemented with cudaMalloc/cudaFree. + //! An alignment value of zero indicates any alignment is acceptable. + //! \param flags Reserved for future use. In the current release, 0 will be passed. + //! \param stream specifies the cudaStream for asynchronous usage. + //! + //! \return If the allocation was successful, the start address of a device memory block of the requested size. + //! If an allocation request of size 0 is made, nullptr must be returned. + //! If an allocation request cannot be satisfied, nullptr must be returned. + //! If a non-null address is returned, it is guaranteed to have the specified alignment. + //! + //! \note The implementation must guarantee thread safety for concurrent allocate/reallocate/deallocate + //! requests. + //! + //! \usage + //! - Allowed context for the API call + //! - Thread-safe: Yes, this method is required to be thread-safe and may be called from multiple threads. + //! + virtual void* allocateAsync( + uint64_t const size, uint64_t const alignment, AllocatorFlags const flags, cudaStream_t /*stream*/) noexcept + { + return allocate(size, alignment, flags); + } + //! + //! \brief A thread-safe callback implemented by the application to handle stream-ordered release of GPU memory. + //! + //! The default behavior is to call method deallocate(), which is synchronous and thus loses + //! any performance benefits of asynchronous deallocation. If you want the benefits of asynchronous + //! deallocation, see discussion of IGpuAsyncAllocator vs. IGpuAllocator in the documentation + //! for nvinfer1::IGpuAllocator. + //! + //! TensorRT may pass a nullptr to this function if it was previously returned by allocate(). + //! + //! \param memory A memory address that was previously returned by an allocate() or reallocate() call of the same + //! allocator object. + //! \param stream specifies the cudaStream for asynchronous usage. + //! + //! \return True if the acquired memory is released successfully. + //! + //! \note The implementation must guarantee thread safety for concurrent allocate/reallocate/deallocate + //! requests. + //! + //! \note The implementation is not required to be asynchronous. It is permitted to synchronize, + //! albeit doing so will lose the performance advantage of asynchronous deallocation. + //! Either way, it is critical that it not actually free the memory until the current + //! stream position is reached. + //! + //! \usage + //! - Allowed context for the API call + //! - Thread-safe: Yes, this method is required to be thread-safe and may be called from multiple threads. + //! + virtual bool deallocateAsync(void* const memory, cudaStream_t /*stream*/) noexcept + { + return deallocate(memory); + } + + //! + //! \brief Return version information associated with this interface. Applications must not override this method. + //! + InterfaceInfo getInterfaceInfo() const noexcept override + { + return {"IGpuAllocator", 1, 0}; + } + +protected: + // @cond SuppressDoxyWarnings + IGpuAllocator(IGpuAllocator const&) = default; + IGpuAllocator(IGpuAllocator&&) = default; + IGpuAllocator& operator=(IGpuAllocator const&) & = default; + IGpuAllocator& operator=(IGpuAllocator&&) & = default; + // @endcond +}; + +} // namespace v_1_0 + +//! +//! \class IGpuAllocator +//! +//! \brief Application-implemented class for controlling allocation on the GPU. +//! +//! \warning The lifetime of an IGpuAllocator object must exceed that of all objects that use it. +//! +//! This class is intended as a base class for allocators that implement synchronous allocation. +//! If you want the benefits of asynchronous allocation, you can do either of: +//! +//! * Derive your class from IGpuAllocator and override all four of its virtual methods +//! for allocation/deallocation, including the two deprecated methods. +//! +//! * Derive your class from IGpuAsyncAllocator and override its two pure virtual +//! methods for allocation/deallocation. +//! +//! The latter style is preferred because it does not tie code to deprecated methods. +//! +//! \see IGpuAsyncAllocator. +//! +using IGpuAllocator = v_1_0::IGpuAllocator; + //! //! \class IRuntime //! @@ -1503,6 +1858,7 @@ class IRuntime : public INoCopy return mImpl->deserializeCudaEngine(streamReader); } + //! //! \brief get the logger with which the runtime was created //! @@ -3533,7 +3889,7 @@ class IDebugListener : public IVersionedInterface //! \param type data Type of the tensor. //! \param shape shape of the tensor. //! \param name name of the tensor. - //! \param stream Cuda stream object. + //! \param stream CUDA stream object. //! //! \return True on success, false otherwise. //! @@ -3647,7 +4003,7 @@ class IExecutionContext : public INoCopy //! //! \brief Set the device memory for use by this execution context. //! - //! The memory must be aligned with cuda memory alignment property (using cudaGetDeviceProperties()), and its size + //! The memory must be aligned with CUDA memory alignment property (using cudaGetDeviceProperties()), and its size //! must be large enough for performing inference with the given network inputs. getDeviceMemorySize() and //! getDeviceMemorySizeForProfile() report upper bounds of the size. Setting memory to nullptr is acceptable if the //! reported size is 0. If using enqueueV3() to run the network, the memory is in use from the invocation of @@ -3674,7 +4030,7 @@ class IExecutionContext : public INoCopy //! //! \brief Set the device memory and its corresponding size for use by this execution context. //! - //! The memory must be aligned with cuda memory alignment property (using cudaGetDeviceProperties()), and its size + //! The memory must be aligned with CUDA memory alignment property (using cudaGetDeviceProperties()), and its size //! must be large enough for performing inference with the given network inputs. getDeviceMemorySize() and //! getDeviceMemorySizeForProfile() report upper bounds of the size. Setting memory to nullptr is acceptable if the //! reported size is 0. If using enqueueV3() to run the network, the memory is in use from the invocation of @@ -3875,7 +4231,7 @@ class IExecutionContext : public INoCopy //! \param profileIndex Index of the profile. The value must lie between 0 and //! getEngine().getNbOptimizationProfiles() - 1 //! - //! \param stream A cuda stream on which the cudaMemcpyAsyncs may be + //! \param stream A CUDA stream on which the cudaMemcpyAsyncs may be //! enqueued //! //! When an optimization profile is switched via this API, TensorRT may @@ -4145,7 +4501,7 @@ class IExecutionContext : public INoCopy //! //! \brief Mark input as consumed. //! - //! \param event The cuda event that is triggered after all input tensors have been consumed. + //! \param event The CUDA event that is triggered after all input tensors have been consumed. //! //! \warning The set event must be valid during the inferece. //! @@ -4161,7 +4517,7 @@ class IExecutionContext : public INoCopy //! //! \brief The event associated with consuming the input. //! - //! \return The cuda event. Nullptr will be returned if the event is not set yet. + //! \return The CUDA event. Nullptr will be returned if the event is not set yet. //! cudaEvent_t getInputConsumedEvent() const noexcept { @@ -4251,7 +4607,7 @@ class IExecutionContext : public INoCopy //! //! \brief Enqueue inference on a stream. //! - //! \param stream A cuda stream on which the inference kernels will be enqueued. + //! \param stream A CUDA stream on which the inference kernels will be enqueued. //! //! \return True if the kernels were enqueued successfully, false otherwise. //! @@ -4844,7 +5200,6 @@ class IGpuAsyncAllocator : public IGpuAllocator //! //! \see IGpuAllocator using IGpuAsyncAllocator = v_1_0::IGpuAsyncAllocator; - } // namespace nvinfer1 //! diff --git a/include/NvInferRuntimeBase.h b/include/NvInferRuntimeBase.h index b6652c07e..bde3d1dd2 100644 --- a/include/NvInferRuntimeBase.h +++ b/include/NvInferRuntimeBase.h @@ -66,12 +66,10 @@ //! //! \warning Do not directly include this file. Instead include one of: //! * NvInferRuntime.h (for the standard runtime) -//! * NvInferSafeRuntime.h (for the safety runtime) -//! * NvInferConsistency.h (for consistency checker) //! * NvInferPluginUtils.h (for plugin utilities) //! -#if !defined(NV_INFER_INTERNAL_INCLUDE_RUNTIME_BASE) -static_assert(false, "Do not directly include this file. Include NvInferRuntime.h or NvInferSafeRuntime.h or NvInferConsistency.h or NvInferPluginUtils.h"); +#if !defined(NV_INFER_INTERNAL_INCLUDE) +static_assert(false, "Do not directly include this file. Include NvInferRuntime.h or NvInferPluginUtils.h"); #endif //! Forward declare some CUDA types to avoid an include dependency. @@ -216,144 +214,6 @@ class Dims64 //! using Dims = Dims64; -//! -//! \enum TensorFormat -//! -//! \brief Format of the input/output tensors. -//! -//! This enum is used by both plugins and network I/O tensors. -//! -//! \see IPluginV2::supportsFormat(), safe::ICudaEngine::getBindingFormat() -//! -//! Many of the formats are **vector-major** or **vector-minor**. These formats specify -//! a vector dimension and scalars per vector. -//! For example, suppose that the tensor has has dimensions [M,N,C,H,W], -//! the vector dimension is C and there are V scalars per vector. -//! -//! * A **vector-major** format splits the vectorized dimension into two axes in the -//! memory layout. The vectorized dimension is replaced by an axis of length ceil(C/V) -//! and a new dimension of length V is appended. For the example tensor, the memory layout -//! is equivalent to an array with dimensions [M][N][ceil(C/V)][H][W][V]. -//! Tensor coordinate (m,n,c,h,w) maps to array location [m][n][c/V][h][w][c\%V]. -//! -//! * A **vector-minor** format moves the vectorized dimension to become the last axis -//! in the memory layout. For the example tensor, the memory layout is equivalent to an -//! array with dimensions [M][N][H][W][ceil(C/V)*V]. Tensor coordinate (m,n,c,h,w) maps -//! array location subscript [m][n][h][w][c]. -//! -//! In interfaces that refer to "components per element", that's the value of V above. -//! -//! For more information about data formats, see the topic "Data Format Description" located in the -//! TensorRT Developer Guide. https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#data-format-desc -//! -enum class TensorFormat : int32_t -{ - //! Memory layout is similar to an array in C or C++. - //! The stride of each dimension is the product of the dimensions after it. - //! The last dimension has unit stride. - //! - //! For DLA usage, the tensor sizes are limited to C,H,W in the range [1,8192]. - kLINEAR = 0, - - //! Vector-major format with two scalars per vector. - //! Vector dimension is third to last. - //! - //! This format requires FP16 and at least three dimensions. - kCHW2 = 1, - - //! Vector-minor format with eight scalars per vector. - //! Vector dimension is third to last. - //! This format requires FP16 or BF16 and at least three dimensions. - kHWC8 = 2, - - //! Vector-major format with four scalars per vector. - //! Vector dimension is third to last. - //! - //! This format requires INT8 or FP16 and at least three dimensions. - //! For INT8, the length of the vector dimension must be a build-time constant. - //! - //! Deprecated usage: - //! - //! If running on the DLA, this format can be used for acceleration - //! with the caveat that C must be less than or equal to 4. - //! If used as DLA input and the build option kGPU_FALLBACK is not specified, - //! it needs to meet line stride requirement of DLA format. Column stride in - //! bytes must be a multiple of 64 on Orin. - kCHW4 = 3, - - //! Vector-major format with 16 scalars per vector. - //! Vector dimension is third to last. - //! - //! This format requires FP16 and at least three dimensions. - //! - //! For DLA usage, this format maps to the native feature format for FP16, - //! and the tensor sizes are limited to C,H,W in the range [1,8192]. - kCHW16 = 4, - - //! Vector-major format with 32 scalars per vector. - //! Vector dimension is third to last. - //! - //! This format requires at least three dimensions. - //! - //! For DLA usage, this format maps to the native feature format for INT8, - //! and the tensor sizes are limited to C,H,W in the range [1,8192]. - kCHW32 = 5, - - //! Vector-minor format with eight scalars per vector. - //! Vector dimension is fourth to last. - //! - //! This format requires FP16 or BF16 and at least four dimensions. - kDHWC8 = 6, - - //! Vector-major format with 32 scalars per vector. - //! Vector dimension is fourth to last. - //! - //! This format requires FP16 or INT8 and at least four dimensions. - kCDHW32 = 7, - - //! Vector-minor format where channel dimension is third to last and unpadded. - //! - //! This format requires either FP32, FP16, UINT8, INT64 or BF16 and at least three dimensions. - kHWC = 8, - - //! DLA planar format. For a tensor with dimension {N, C, H, W}, the W axis - //! always has unit stride. The stride for stepping along the H axis is - //! rounded up to 64 bytes. - //! - //! The memory layout is equivalent to a C array with dimensions - //! [N][C][H][roundUp(W, 64/elementSize)] where elementSize is - //! 2 for FP16 and 1 for Int8, with the tensor coordinates (n, c, h, w) - //! mapping to array subscript [n][c][h][w]. - kDLA_LINEAR = 9, - - //! DLA image format. For a tensor with dimension {N, C, H, W} the C axis - //! always has unit stride. The stride for stepping along the H axis is rounded up - //! to 64 bytes on Orin. C can only be 1, 3 or 4. - //! If C == 1, it will map to grayscale format. - //! If C == 3 or C == 4, it will map to color image format. And if C == 3, - //! the stride for stepping along the W axis needs to be padded to 4 in elements. - //! - //! When C is {1, 3, 4}, then C' is {1, 4, 4} respectively, - //! the memory layout is equivalent to a C array with dimensions - //! [N][H][roundUp(W, 64/C'/elementSize)][C'] on Orin - //! where elementSize is 2 for FP16 - //! and 1 for Int8. The tensor coordinates (n, c, h, w) mapping to array - //! subscript [n][h][w][c]. - kDLA_HWC4 = 10, - - //! Vector-minor format with 16 scalars per vector. - //! Vector dimension is third to last. - //! - //! This requires FP16 or INT8 and at least three dimensions. - kHWC16 = 11, - - //! Vector-minor format with one scalar per vector. - //! Vector dimension is fourth to last. - //! - //! This format requires FP32 and at least four dimensions. - kDHWC = 12 -}; - using InterfaceKind = char const*; //! @@ -424,326 +284,6 @@ class IVersionedInterface IVersionedInterface& operator=(IVersionedInterface&&) & = default; }; -namespace impl -{ -//! Maximum number of elements in TensorFormat enum. \see TensorFormat -template <> -struct EnumMaxImpl -{ - //! Declaration of kVALUE that represents the maximum number of elements in the TensorFormat enum. - static constexpr int32_t kVALUE = 13; -}; -} // namespace impl - - -//! -//! \enum AllocatorFlag -//! -//! \brief Allowed type of memory allocation. -//! -enum class AllocatorFlag : int32_t -{ - //! TensorRT may call realloc() on this allocation. - kRESIZABLE = 0, -}; - -namespace impl -{ -//! Maximum number of elements in AllocatorFlag enum. \see AllocatorFlag -template <> -struct EnumMaxImpl -{ - //! Declaration of kVALUE that represents the maximum number of elements in the AllocatorFlag enum. - static constexpr int32_t kVALUE = 1; -}; -} // namespace impl - -using AllocatorFlags = uint32_t; - -//! DO NOT REFER TO namespace v_1_0 IN CODE. ALWAYS USE nvinfer1 INSTEAD. -//! The name v_1_0 may change in future versions of TensoRT. -namespace v_1_0 -{ - -class IGpuAllocator : public IVersionedInterface -{ -public: - //! - //! \brief A thread-safe callback implemented by the application to handle acquisition of GPU memory. - //! - //! \param size The size of the memory block required (in bytes). - //! \param alignment The required alignment of memory. Alignment will be zero - //! or a power of 2 not exceeding the alignment guaranteed by cudaMalloc. - //! Thus this allocator can be safely implemented with cudaMalloc/cudaFree. - //! An alignment value of zero indicates any alignment is acceptable. - //! \param flags Reserved for future use. In the current release, 0 will be passed. - //! - //! \return If the allocation was successful, the start address of a device memory block of the requested size. - //! If an allocation request of size 0 is made, nullptr must be returned. - //! If an allocation request cannot be satisfied, nullptr must be returned. - //! If a non-null address is returned, it is guaranteed to have the specified alignment. - //! - //! \note The implementation must guarantee thread safety for concurrent allocate/reallocate/deallocate - //! requests. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes, this method is required to be thread-safe and may be called from multiple threads. - //! - //! \deprecated Deprecated in TensorRT 10.0. Superseded by allocateAsync - //! - TRT_DEPRECATED virtual void* allocate( - uint64_t const size, uint64_t const alignment, AllocatorFlags const flags) noexcept = 0; - - ~IGpuAllocator() override = default; - IGpuAllocator() = default; - - //! - //! \brief A thread-safe callback implemented by the application to resize an existing allocation. - //! - //! Only allocations which were allocated with AllocatorFlag::kRESIZABLE will be resized. - //! - //! Options are one of: - //! * resize in place leaving min(oldSize, newSize) bytes unchanged and return the original address - //! * move min(oldSize, newSize) bytes to a new location of sufficient size and return its address - //! * return nullptr, to indicate that the request could not be fulfilled. - //! - //! If nullptr is returned, TensorRT will assume that resize() is not implemented, and that the - //! allocation at baseAddr is still valid. - //! - //! This method is made available for use cases where delegating the resize - //! strategy to the application provides an opportunity to improve memory management. - //! One possible implementation is to allocate a large virtual device buffer and - //! progressively commit physical memory with cuMemMap. CU_MEM_ALLOC_GRANULARITY_RECOMMENDED - //! is suggested in this case. - //! - //! TensorRT may call realloc to increase the buffer by relatively small amounts. - //! - //! \param baseAddr the address of the original allocation, which will have been returned by previously calling - //! allocate() or reallocate() on the same object. - //! \param alignment The alignment used by the original allocation. This will be the same value that was previously - //! passed to the allocate() or reallocate() call that returned baseAddr. - //! \param newSize The new memory size required (in bytes). - //! - //! \return The address of the reallocated memory, or nullptr. If a non-null address is returned, it is - //! guaranteed to have the specified alignment. - //! - //! \note The implementation must guarantee thread safety for concurrent allocate/reallocate/deallocate - //! requests. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes, this method is required to be thread-safe and may be called from multiple threads. - //! - virtual void* reallocate(void* const /*baseAddr*/, uint64_t /*alignment*/, uint64_t /*newSize*/) noexcept - { - return nullptr; - } - - //! - //! \brief A thread-safe callback implemented by the application to handle release of GPU memory. - //! - //! TensorRT may pass a nullptr to this function if it was previously returned by allocate(). - //! - //! \param memory A memory address that was previously returned by an allocate() or reallocate() call of the same - //! allocator object. - //! - //! \return True if the acquired memory is released successfully. - //! - //! \note The implementation must guarantee thread safety for concurrent allocate/reallocate/deallocate - //! requests. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes, this method is required to be thread-safe and may be called from multiple threads. - //! \deprecated Deprecated in TensorRT 10.0. Superseded by deallocateAsync - //! - TRT_DEPRECATED virtual bool deallocate(void* const memory) noexcept = 0; - - //! - //! \brief A thread-safe callback implemented by the application to handle stream-ordered acquisition of GPU memory. - //! - //! The default behavior is to call method allocate(), which is synchronous and thus loses - //! any performance benefits of asynchronous allocation. If you want the benefits of asynchronous - //! allocation, see discussion of IGpuAsyncAllocator vs. IGpuAllocator in the documentation - //! for nvinfer1::IGpuAllocator. - //! - //! \param size The size of the memory block required (in bytes). - //! \param alignment The required alignment of memory. Alignment will be zero - //! or a power of 2 not exceeding the alignment guaranteed by cudaMalloc. - //! Thus this allocator can be safely implemented with cudaMalloc/cudaFree. - //! An alignment value of zero indicates any alignment is acceptable. - //! \param flags Reserved for future use. In the current release, 0 will be passed. - //! \param stream specifies the cudaStream for asynchronous usage. - //! - //! \return If the allocation was successful, the start address of a device memory block of the requested size. - //! If an allocation request of size 0 is made, nullptr must be returned. - //! If an allocation request cannot be satisfied, nullptr must be returned. - //! If a non-null address is returned, it is guaranteed to have the specified alignment. - //! - //! \note The implementation must guarantee thread safety for concurrent allocate/reallocate/deallocate - //! requests. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes, this method is required to be thread-safe and may be called from multiple threads. - //! - virtual void* allocateAsync( - uint64_t const size, uint64_t const alignment, AllocatorFlags const flags, cudaStream_t /*stream*/) noexcept - { - return allocate(size, alignment, flags); - } - //! - //! \brief A thread-safe callback implemented by the application to handle stream-ordered release of GPU memory. - //! - //! The default behavior is to call method deallocate(), which is synchronous and thus loses - //! any performance benefits of asynchronous deallocation. If you want the benefits of asynchronous - //! deallocation, see discussion of IGpuAsyncAllocator vs. IGpuAllocator in the documentation - //! for nvinfer1::IGpuAllocator. - //! - //! TensorRT may pass a nullptr to this function if it was previously returned by allocate(). - //! - //! \param memory A memory address that was previously returned by an allocate() or reallocate() call of the same - //! allocator object. - //! \param stream specifies the cudaStream for asynchronous usage. - //! - //! \return True if the acquired memory is released successfully. - //! - //! \note The implementation must guarantee thread safety for concurrent allocate/reallocate/deallocate - //! requests. - //! - //! \note The implementation is not required to be asynchronous. It is permitted to synchronize, - //! albeit doing so will lose the performance advantage of asynchronous deallocation. - //! Either way, it is critical that it not actually free the memory until the current - //! stream position is reached. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes, this method is required to be thread-safe and may be called from multiple threads. - //! - virtual bool deallocateAsync(void* const memory, cudaStream_t /*stream*/) noexcept - { - return deallocate(memory); - } - - //! - //! \brief Return version information associated with this interface. Applications must not override this method. - //! - InterfaceInfo getInterfaceInfo() const noexcept override - { - return {"IGpuAllocator", 1, 0}; - } - -protected: - // @cond SuppressDoxyWarnings - IGpuAllocator(IGpuAllocator const&) = default; - IGpuAllocator(IGpuAllocator&&) = default; - IGpuAllocator& operator=(IGpuAllocator const&) & = default; - IGpuAllocator& operator=(IGpuAllocator&&) & = default; - // @endcond -}; - -} // namespace v_1_0 - -//! -//! \class IGpuAllocator -//! -//! \brief Application-implemented class for controlling allocation on the GPU. -//! -//! \warning The lifetime of an IGpuAllocator object must exceed that of all objects that use it. -//! -//! This class is intended as a base class for allocators that implement synchronous allocation. -//! If you want the benefits of asynchronous allocation, you can do either of: -//! -//! * Derive your class from IGpuAllocator and override all four of its virtual methods -//! for allocation/deallocation, including the two deprecated methods. -//! -//! * Derive your class from IGpuAsyncAllocator and override its two pure virtual -//! methods for allocation/deallocation. -//! -//! The latter style is preferred because it does not tie code to deprecated methods. -//! -//! \see IGpuAsyncAllocator. -//! -using IGpuAllocator = v_1_0::IGpuAllocator; - -//! -//! \class ILogger -//! -//! \brief Application-implemented logging interface for the builder, refitter and runtime. -//! -//! The logger used to create an instance of IBuilder, IRuntime or IRefitter is used for all objects created through -//! that interface. The logger must be valid until all objects created are released. -//! -//! The Logger object implementation must be thread safe. All locking and synchronization is pushed to the -//! interface implementation and TensorRT does not hold any synchronization primitives when calling the interface -//! functions. -//! -class ILogger -{ -public: - //! - //! \enum Severity - //! - //! \brief The severity corresponding to a log message. - //! - enum class Severity : int32_t - { - //! An internal error has occurred. Execution is unrecoverable. - kINTERNAL_ERROR = 0, - //! An application error has occurred. - kERROR = 1, - //! An application error has been discovered, but TensorRT has recovered or fallen back to a default. - kWARNING = 2, - //! Informational messages with instructional information. - kINFO = 3, - //! Verbose messages with debugging information. - kVERBOSE = 4, - }; - - //! - //! \brief A callback implemented by the application to handle logging messages; - //! - //! \param severity The severity of the message. - //! \param msg A null-terminated log message. - //! - //! \warning Loggers used in the safety certified runtime must set a maximum message length and truncate - //! messages exceeding this length. It is up to the implementer of the derived class to define - //! a suitable limit that will prevent buffer overruns, resource exhaustion, and other security - //! vulnerabilities in their implementation. The TensorRT safety certified runtime will never - //! emit messages longer than 1024 bytes. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes, this method is required to be thread-safe and may be called from multiple threads - //! when multiple execution contexts are used during runtime, or if the same logger is used - //! for multiple runtimes, builders, or refitters. - //! - virtual void log(Severity severity, AsciiChar const* msg) noexcept = 0; - - ILogger() = default; - virtual ~ILogger() = default; - -protected: -// @cond SuppressDoxyWarnings - ILogger(ILogger const&) = default; - ILogger(ILogger&&) = default; - ILogger& operator=(ILogger const&) & = default; - ILogger& operator=(ILogger&&) & = default; -// @endcond -}; - -namespace impl -{ -//! Maximum number of elements in ILogger::Severity enum. \see ILogger::Severity -template <> -struct EnumMaxImpl -{ - //! Declaration of kVALUE that represents the maximum number of elements in the ILogger::Severity enum. - static constexpr int32_t kVALUE = 5; -}; -} // namespace impl - //! //! \enum ErrorCode //! @@ -1108,116 +648,6 @@ enum class TensorIOMode : int32_t kOUTPUT = 2 }; -namespace v_1_0 -{ -class IStreamReader : public IVersionedInterface -{ -public: - //! - //! TensorRT never calls the destructor for an IStreamReader defined by the - //! application. - //! - ~IStreamReader() override = default; - IStreamReader() = default; - - //! - //! \brief Return version information associated with this interface. Applications must not override this method. - //! - InterfaceInfo getInterfaceInfo() const noexcept override - { - return InterfaceInfo{"IStreamReader", 1, 0}; - } - - //! - //! \brief Read the next number of bytes in the stream. - //! - //! \param destination The memory to write to - //! \param nbBytes The number of bytes to read - //! - //! \returns The number of bytes read. Negative values will be considered an automatic error. - //! - virtual int64_t read(void* destination, int64_t nbBytes) = 0; - -protected: - IStreamReader(IStreamReader const&) = default; - IStreamReader(IStreamReader&&) = default; - IStreamReader& operator=(IStreamReader const&) & = default; - IStreamReader& operator=(IStreamReader&&) & = default; -}; -} // namespace v_1_0 - -//! -//! \class IStreamReader -//! -//! \brief Application-implemented class for reading data in a stream-based manner. -//! -//! \note To ensure compatibility of source code with future versions of TensorRT, use IStreamReader, not -//! v_1_0::IStreamReader -//! -using IStreamReader = v_1_0::IStreamReader; - -namespace v_1_0 -{ - -class IPluginResource : public IVersionedInterface -{ -public: - //! - //! \brief Return version information associated with this interface. Applications must not override this method. - //! - InterfaceInfo getInterfaceInfo() const noexcept override - { - return InterfaceInfo{"IPluginResource", 1, 0}; - } - //! - //! \brief Free the underlying resource - //! - //! This will only be called for IPluginResource objects that were produced from IPluginResource::clone() - //! - //! The IPluginResource object on which release() is called must still be in a clone-able state - //! after release() returns - //! - //! \return 0 for success, else non-zero - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: No; this method is not required to be thread-safe - //! - virtual int32_t release() noexcept = 0; - - //! - //! \brief Clone the resource object - //! - //! \note Resource initialization (if any) may be skipped for non-cloned objects since only clones will be - //! registered by TensorRT - //! - //! \return Pointer to cloned object. nullptr if there was an issue. - //! - //! \usage - //! - Allowed context for the API call - //! - Thread-safe: Yes; this method is required to be thread-safe and may be called from multiple threads. - //! - virtual IPluginResource* clone() noexcept = 0; - - ~IPluginResource() noexcept override = default; - - IPluginResource() = default; - IPluginResource(IPluginResource const&) = default; - IPluginResource(IPluginResource&&) = default; - IPluginResource& operator=(IPluginResource const&) & = default; - IPluginResource& operator=(IPluginResource&&) & = default; -}; // class IPluginResource -} // namespace v_1_0 - -//! -//! \class IPluginResource -//! -//! \brief Interface for plugins to define custom resources that could be shared through the plugin registry -//! -//! \see IPluginRegistry::acquirePluginResource -//! \see IPluginRegistry::releasePluginResource -//! -using IPluginResource = v_1_0::IPluginResource; - namespace impl { //! Maximum number of elements in TensorIOMode enum. \see TensorIOMode diff --git a/include/NvInferRuntimeCommon.h b/include/NvInferRuntimeCommon.h index 13e42f4fa..19b83b36b 100644 --- a/include/NvInferRuntimeCommon.h +++ b/include/NvInferRuntimeCommon.h @@ -28,9 +28,9 @@ //! //! \warning Do not directly include this file. Instead include NvInferRuntime.h //! -#define NV_INFER_INTERNAL_INCLUDE_RUNTIME_BASE 1 -#include "NvInferRuntimeBase.h" -#undef NV_INFER_INTERNAL_INCLUDE_RUNTIME_BASE +#define NV_INFER_INTERNAL_INCLUDE 1 +#include "NvInferPluginBase.h" +#undef NV_INFER_INTERNAL_INCLUDE #include "NvInferRuntimePlugin.h" namespace nvinfer1 diff --git a/include/NvInferRuntimePlugin.h b/include/NvInferRuntimePlugin.h index dffdd9017..dbe5bb49c 100644 --- a/include/NvInferRuntimePlugin.h +++ b/include/NvInferRuntimePlugin.h @@ -18,9 +18,9 @@ #ifndef NV_INFER_RUNTIME_PLUGIN_H #define NV_INFER_RUNTIME_PLUGIN_H -#define NV_INFER_INTERNAL_INCLUDE_RUNTIME_BASE 1 -#include "NvInferRuntimeBase.h" -#undef NV_INFER_INTERNAL_INCLUDE_RUNTIME_BASE +#define NV_INFER_INTERNAL_INCLUDE 1 +#include "NvInferPluginBase.h" +#undef NV_INFER_INTERNAL_INCLUDE //! //! \file NvInferRuntimePlugin.h @@ -28,8 +28,7 @@ //! This file contains common definitions, data structures and interfaces that relate to plugins and are shared //! between the standard and safe runtime. //! -//! \warning Do not directly include this file. Instead include either NvInferRuntime.h (for the standard runtime) or -//! NvInferSafeRuntime.h (for the safety runtime). +//! \warning Do not directly include this file. Instead include NvInferRuntime.h //! //! @@ -40,6 +39,13 @@ namespace nvinfer1 { +enum class TensorFormat : int32_t; +namespace v_1_0 +{ +class IGpuAllocator; +} +using IGpuAllocator = v_1_0::IGpuAllocator; + //! //! \brief PluginFormat is reserved for backward compatibility. //! @@ -824,126 +830,8 @@ class TRT_DEPRECATED IPluginV2IOExt : public IPluginV2Ext } }; -//! -//! \enum PluginFieldType -//! -//! \brief The possible field types for custom layer. -//! -enum class PluginFieldType : int32_t -{ - //! FP16 field type. - kFLOAT16 = 0, - //! FP32 field type. - kFLOAT32 = 1, - //! FP64 field type. - kFLOAT64 = 2, - //! INT8 field type. - kINT8 = 3, - //! INT16 field type. - kINT16 = 4, - //! INT32 field type. - kINT32 = 5, - //! char field type. - kCHAR = 6, - //! nvinfer1::Dims field type. - kDIMS = 7, - //! Unknown field type. - kUNKNOWN = 8, - //! BF16 field type. - kBF16 = 9, - //! INT64 field type. - kINT64 = 10, - //! FP8 field type. - kFP8 = 11, - //! INT4 field type. - kINT4 = 12, -}; - -//! -//! \class PluginField -//! -//! \brief Structure containing plugin attribute field names and associated data -//! This information can be parsed to decode necessary plugin metadata -//! -//! -class PluginField -{ -public: - //! Plugin field attribute name - AsciiChar const* name; - //! Plugin field attribute data - void const* data; - //! Plugin field attribute type - PluginFieldType type; - //! Number of data entries in the Plugin attribute - int32_t length; - - PluginField(AsciiChar const* const name_ = nullptr, void const* const data_ = nullptr, - PluginFieldType const type_ = PluginFieldType::kUNKNOWN, int32_t const length_ = 0) noexcept - : name(name_) - , data(data_) - , type(type_) - , length(length_) - { - } -}; - -//! -//! \struct PluginFieldCollection -//! -//! \brief Plugin field collection struct. -//! -struct PluginFieldCollection -{ - //! Number of PluginField entries. - int32_t nbFields; - //! Pointer to PluginField entries. - PluginField const* fields; -}; - -//! -//! \enum PluginCapabilityType -//! -//! \brief Enumerates the different capability types a IPluginV3 object may have -//! -enum class PluginCapabilityType : int32_t -{ - //! Core capability. Every IPluginV3 object must have this. - kCORE = 0, - //! Build capability. IPluginV3 objects provided to TensorRT build phase must have this. - kBUILD = 1, - //! Runtime capability. IPluginV3 objects provided to TensorRT build and execution phases must have this. - kRUNTIME = 2 -}; - -//! -//! \enum TensorRTPhase -//! -//! \brief Indicates a phase of operation of TensorRT -//! -enum class TensorRTPhase : int32_t -{ - //! Build phase of TensorRT - kBUILD = 0, - //! Execution phase of TensorRT - kRUNTIME = 1 -}; - namespace v_1_0 { -class IPluginCreatorInterface : public IVersionedInterface -{ -public: - ~IPluginCreatorInterface() noexcept override = default; - -protected: - IPluginCreatorInterface() = default; - IPluginCreatorInterface(IPluginCreatorInterface const&) = default; - IPluginCreatorInterface(IPluginCreatorInterface&&) = default; - IPluginCreatorInterface& operator=(IPluginCreatorInterface const&) & = default; - IPluginCreatorInterface& operator=(IPluginCreatorInterface&&) & = default; -}; - class TRT_DEPRECATED IPluginCreator : public IPluginCreatorInterface { public: @@ -1071,15 +959,6 @@ class TRT_DEPRECATED IPluginCreator : public IPluginCreatorInterface }; } // namespace v_1_0 -//! -//! \class IPluginCreatorInterface -//! -//! \brief Base class for all plugin creator versions. -//! -//! \see IPluginCreator and IPluginRegistry -//! -using IPluginCreatorInterface = v_1_0::IPluginCreatorInterface; - //! //! \class IPluginCreator //! diff --git a/include/NvInferVersion.h b/include/NvInferVersion.h index 084b4ee16..d0d785124 100644 --- a/include/NvInferVersion.h +++ b/include/NvInferVersion.h @@ -24,9 +24,9 @@ #define NV_INFER_VERSION_H #define NV_TENSORRT_MAJOR 10 //!< TensorRT major version. -#define NV_TENSORRT_MINOR 5 //!< TensorRT minor version. +#define NV_TENSORRT_MINOR 6 //!< TensorRT minor version. #define NV_TENSORRT_PATCH 0 //!< TensorRT patch version. -#define NV_TENSORRT_BUILD 18 //!< TensorRT build number. +#define NV_TENSORRT_BUILD 26 //!< TensorRT build number. #define NV_TENSORRT_LWS_MAJOR 0 //!< TensorRT LWS major version. #define NV_TENSORRT_LWS_MINOR 0 //!< TensorRT LWS minor version. diff --git a/parsers/onnx b/parsers/onnx index 886aff917..4442153a4 160000 --- a/parsers/onnx +++ b/parsers/onnx @@ -1 +1 @@ -Subproject commit 886aff917b63f10a81c5f31e89752a3b46169623 +Subproject commit 4442153a4483c29e109241eb11752f3e59be62f8 diff --git a/plugin/README.md b/plugin/README.md index 8f024104c..4f13c98af 100644 --- a/plugin/README.md +++ b/plugin/README.md @@ -19,8 +19,8 @@ | [efficientNMSPlugin](efficientNMSPlugin) | EfficientNMS_TRT | 1 | | [efficientNMSONNXPlugin](efficientNMSPlugin) [DEPRECATED] | EfficientNMS_ONNX_TRT | 1 | | [embLayerNormPlugin](embLayerNormPlugin) [DEPRECATED]| CustomEmbLayerNormPluginDynamic | 1, 2, 3 | -| [embLayerNormPlugin](embLayerNormPlugin) | CustomEmbLayerNormPluginDynamic | 4, 5 | -| [fcPlugin](fcPlugin) | CustomFCPluginDynamic | 1 | +| [embLayerNormPlugin](embLayerNormPlugin) | CustomEmbLayerNormPluginDynamic | 4, 5, 6 | +| [fcPlugin](fcPlugin) [DEPRECATED] | CustomFCPluginDynamic | 1 | | [flattenConcat](flattenConcat) | FlattenConcat_TRT | 1 | | [geluPlugin](geluPlugin) [DEPRECATED] | CustomGeluPluginDynamic | 1 | | [generateDetectionPlugin](generateDetectionPlugin) | GenerateDetection_TRT | 1 | diff --git a/plugin/api/inferPlugin.cpp b/plugin/api/inferPlugin.cpp index 28c42cae2..5067963ea 100644 --- a/plugin/api/inferPlugin.cpp +++ b/plugin/api/inferPlugin.cpp @@ -16,11 +16,13 @@ */ #include "NvInfer.h" #include "NvInferPlugin.h" +#include "common/checkMacrosPlugin.h" +#include "common/plugin.h" +#include "roiAlignPlugin/roiAlignPlugin.h" +#if !TRT_WINML #include "batchTilePlugin/batchTilePlugin.h" #include "batchedNMSPlugin/batchedNMSPlugin.h" #include "clipPlugin/clipPlugin.h" -#include "common/checkMacrosPlugin.h" -#include "common/plugin.h" #include "coordConvACPlugin/coordConvACPlugin.h" #include "cropAndResizePlugin/cropAndResizePlugin.h" #include "decodeBbox3DPlugin/decodeBbox3D.h" @@ -57,7 +59,7 @@ #include "specialSlicePlugin/specialSlicePlugin.h" #include "splitPlugin/split.h" #include "voxelGeneratorPlugin/voxelGenerator.h" - +#endif #include #include #include @@ -181,6 +183,8 @@ extern "C" { bool initLibNvInferPlugins(void* logger, char const* libNamespace) { + initializePlugin(logger, libNamespace); +#if !TRT_WINML initializePlugin(logger, libNamespace); initializePlugin(logger, libNamespace); initializePlugin(logger, libNamespace); @@ -220,14 +224,14 @@ extern "C" initializePlugin(logger, libNamespace); initializePlugin(logger, libNamespace); initializePlugin(logger, libNamespace); - initializePlugin(logger, libNamespace); initializePlugin(logger, libNamespace); - initializePlugin(logger, libNamespace); initializePlugin(logger, libNamespace); + initializePlugin(logger, libNamespace); initializePlugin(logger, libNamespace); initializePlugin(logger, libNamespace); initializePlugin(logger, libNamespace); initializePlugin(logger, libNamespace); +#endif return true; } } // extern "C" diff --git a/plugin/embLayerNormPlugin/CustomEmbLayerNormPluginDynamic_PluginConfig.yaml b/plugin/embLayerNormPlugin/CustomEmbLayerNormPluginDynamic_PluginConfig.yaml index 6d8d1125c..62b95b359 100644 --- a/plugin/embLayerNormPlugin/CustomEmbLayerNormPluginDynamic_PluginConfig.yaml +++ b/plugin/embLayerNormPlugin/CustomEmbLayerNormPluginDynamic_PluginConfig.yaml @@ -16,9 +16,9 @@ # --- name: CustomEmbLayerNormPluginDynamic -interface: "IPluginV2DynamicExt" +interface: "IPluginV3" versions: - "1": + "6": inputs: - token_id - segment_id diff --git a/plugin/embLayerNormPlugin/README.md b/plugin/embLayerNormPlugin/README.md index ca0ed259c..46e9595f3 100644 --- a/plugin/embLayerNormPlugin/README.md +++ b/plugin/embLayerNormPlugin/README.md @@ -31,7 +31,7 @@ Assuming contiguous input masks, encodes the masks as a single number denoting t The version 1 `embLayerNormPlugin` takes three inputs; `token_id`, `segment_id`, and `input_mask`. The subsequent versions 2,3,4,5 (variable seqlen) take four inputs; `token_id`, `segment_id`, `cu_seqlen`, and `max_seqlen`. -### Version 1 +### Version 1 & 6 Inputs: - `token_id` An input sequence containing token ids. token_id is an `int32` tensor with shape `[S, B,]` where `S` is the sequence length and `B` is the batch size. @@ -56,7 +56,7 @@ The final output embedding is the sum of embeddings for the token, the segment a The `maskIdx` is a more compact representation of the input mask, consisting of the number of valid elements, assuming that the original mask was contiguous. For fixed sequence length version 1, the `maskIdx` is an `int32` tensor with shape `[B, packSize]` where `B` is batch size, `packSize` is the packed mask size that depends on the sequence length. -### Version >= 2 +### 6 > Version >= 2 Inputs: - `token_id` @@ -95,17 +95,17 @@ The final output embedding is the sum of embeddings for the token, the segment a The parameters are defined below and consists of the following attributes: -| Type | Parameter | Version | Description -|----------|----------------------------------------|----------------|-------------------------------------------------------- -|`int` |`output_fp16` | 1, 2, 3, 4, 5 |Integer encoding the DataType, set 0 when build FP32 network and set 1 when build FP32/INT8 network (0: FP32, 1: FP16) -|`int` |`full_mask` | 1 |Whether to output the full mask that works with the specialized multi-head-attention plugin kernels (this is deprecated, please use mha_type_id) -|`int` |`mha_type_id` | 1 |Integer encoding the multi-head-attention plugin DataType (0: FP32, 1: FP16, 2: INT8) -|`Weights` |`bert_embeddings_layernorm_beta` | 1, 2, 3, 4, 5 |Beta parameter for layer norm. Shape: `[E,]` where `E` is hidden size -|`Weights` |`bert_embeddings_layernorm_gamma` | 1, 2, 3, 4, 5 |Gamma parameter for layer norm. Shape: `[E,]` where `E` is hidden size -|`Weights` |`bert_embeddings_word_embeddings` | 1, 2, 3, 4, 5 |Token embedding matrix. Shape: `[word_vocab_size, E]` where `E` is hidden size -|`Weights` |`bert_embeddings_token_type_embeddings` | 1, 2, 3, 4, 5 |Token type embedding matrix. Shape: `[type_vocab_size, E]` where `E` is hidden size -|`Weights` |`bert_embeddings_position_embeddings` | 1, 2, 3, 4, 5 |Positional embedding matrix. Shape: `[S, E]` where `S` is the maximum sequence length and `E` is hidden size - +| Type | Parameter | Version | Description +|----------|----------------------------------------|-------------------|-------------------------------------------------------- +|`int` |`output_fp16` | 1, 2, 3, 4, 5, 6 |Integer encoding the DataType, set 0 when build FP32 network and set 1 when build FP32/INT8 network (0: FP32, 1: FP16) +|`int` |`full_mask` | 1, 6 |Whether to output the full mask that works with the specialized multi-head-attention plugin kernels (this is deprecated, please use mha_type_id) +|`int` |`mha_type_id` | 1, 6 |Integer encoding the multi-head-attention plugin DataType (0: FP32, 1: FP16, 2: INT8) +|`Weights` |`bert_embeddings_layernorm_beta` | 1, 2, 3, 4, 5, 6 |Beta parameter for layer norm. Shape: `[E,]` where `E` is hidden size +|`Weights` |`bert_embeddings_layernorm_gamma` | 1, 2, 3, 4, 5, 6 |Gamma parameter for layer norm. Shape: `[E,]` where `E` is hidden size +|`Weights` |`bert_embeddings_word_embeddings` | 1, 2, 3, 4, 5, 6 |Token embedding matrix. Shape: `[word_vocab_size, E]` where `E` is hidden size +|`Weights` |`bert_embeddings_token_type_embeddings` | 1, 2, 3, 4, 5, 6 |Token type embedding matrix. Shape: `[type_vocab_size, E]` where `E` is hidden size +|`Weights` |`bert_embeddings_position_embeddings` | 1, 2, 3, 4, 5, 6 |Positional embedding matrix. Shape: `[S, E]` where `S` is the maximum sequence length and `E` is hidden size +Note: version 1, 2, 3 are deprecated and will be removed in a future release; please use their corresponding updated versions: 6, 4, 5 respectively. ## Additional resources @@ -123,6 +123,9 @@ documentation. ## Changelog +September 2024: +Added `EmblayerNormPlugin` version 6 that mirrors version 1 in IO and attributes (but uses underlying `IPluginV3` implementation instead of the deprecated `IPluginV2DynamicExt` interface) + July 2024: Add `EmbLayerNormPlugin` versions 3 & 4 that duplicate the behavior of v2 and v3 plugins respectively, but implement the `IPluginV3` interface instead of the deprecated `IPluginV2DynamicExt` interface. Update this README with updated description of I/O and structure. diff --git a/plugin/embLayerNormPlugin/embLayerNormPlugin.cpp b/plugin/embLayerNormPlugin/embLayerNormPlugin.cpp index ab523971b..c682d48fa 100644 --- a/plugin/embLayerNormPlugin/embLayerNormPlugin.cpp +++ b/plugin/embLayerNormPlugin/embLayerNormPlugin.cpp @@ -32,8 +32,8 @@ using namespace nvinfer1::plugin::bert; namespace { -char const* EMB_LAYER_NORM_VERSION{"1"}; -char const* EMB_LAYER_NORM_NAME{"CustomEmbLayerNormPluginDynamic"}; +char const* gEmbLayerNormVersion{"6"}; +char const* gEmbLayerNormName{"CustomEmbLayerNormPluginDynamic"}; } // namespace // Static class fields initialization @@ -48,7 +48,6 @@ EmbLayerNormPluginDynamic::EmbLayerNormPluginDynamic(std::string const& name, Da : mLayerName(name) , mLd(beta.count) , mType(type) - , mUseFullMask(useFullMask) , mMhaType(mhaType) { // Assuming Weights.count is the number of elements and not bytes @@ -61,8 +60,9 @@ EmbLayerNormPluginDynamic::EmbLayerNormPluginDynamic(std::string const& name, Da mPosVocabSize = posEmb.count / mLd; mTokVocabSize = tokEmb.count / mLd; mSM = getSMVersion(); - // mS is set during configure - + mOutputFp16 = mType == DataType::kHALF ? 1 : 0; + mUseFullMask = static_cast(useFullMask); + // NOTE: mS is set during configure mBeta.convertAndCopy(beta, nvinfer1::DataType::kFLOAT); mGamma.convertAndCopy(gamma, nvinfer1::DataType::kFLOAT); mWordEmb.convertAndCopy(wordEmb, mType); @@ -76,235 +76,179 @@ EmbLayerNormPluginDynamic::EmbLayerNormPluginDynamic(std::string const& name, Da copyToDevice(mTokEmb, getWeightsSize(mTokEmb, mType), mTokEmbDev); } -EmbLayerNormPluginDynamic::EmbLayerNormPluginDynamic(std::string const& name, void const* data, size_t length) - : mLayerName(name) - , mGammaDev(nullptr) - , mBetaDev(nullptr) - , mWordEmbDev(nullptr) - , mTokEmbDev(nullptr) - , mPosEmbDev(nullptr) -{ - BERT_DEBUG_MSG("EmbLayerNormPluginDynamic deserialize."); - - // Deserialize in the same order as serialization - deserialize_value(&data, &length, &mType); - deserialize_value(&data, &length, &mMhaType); - deserialize_value(&data, &length, &mLd); - deserialize_value(&data, &length, &mS); - deserialize_value(&data, &length, &mWordVocabSize); - deserialize_value(&data, &length, &mPosVocabSize); - deserialize_value(&data, &length, &mTokVocabSize); - deserialize_value(&data, &length, &mUseFullMask); - deserialize_value(&data, &length, &mSM); - - char const* d = static_cast(data); - mBeta.convertAndCopy(d, mLd, nvinfer1::DataType::kFLOAT); - mGamma.convertAndCopy(d, mLd, nvinfer1::DataType::kFLOAT); - mWordEmb.convertAndCopy(d, mLd * mWordVocabSize, mType); - mPosEmb.convertAndCopy(d, mLd * mPosVocabSize, mType); - mTokEmb.convertAndCopy(d, mLd * mTokVocabSize, mType); - - copyToDevice(mGamma, sizeof(float) * mGamma.count, mGammaDev); - copyToDevice(mBeta, sizeof(float) * mBeta.count, mBetaDev); - copyToDevice(mWordEmb, getWeightsSize(mWordEmb, mType), mWordEmbDev); - copyToDevice(mPosEmb, getWeightsSize(mPosEmb, mType), mPosEmbDev); - copyToDevice(mTokEmb, getWeightsSize(mTokEmb, mType), mTokEmbDev); -} - -// IPluginV2DynamicExt Methods -IPluginV2DynamicExt* EmbLayerNormPluginDynamic::clone() const noexcept +EmbLayerNormPluginDynamic::~EmbLayerNormPluginDynamic() { try { - BERT_DEBUG_MSG("EmbLayerNormPluginDynamic clone."); - - auto p = new EmbLayerNormPluginDynamic( - mLayerName, mType, mMhaType, mBeta, mGamma, mWordEmb, mPosEmb, mTokEmb, mUseFullMask); - p->mS = mS; - p->setPluginNamespace(mNamespace.c_str()); - - return p; + // This gets called when the network containing plugin is destroyed + mGammaDev.reset(nullptr); + mBetaDev.reset(nullptr); + mWordEmbDev.reset(nullptr); + mPosEmbDev.reset(nullptr); + mTokEmbDev.reset(nullptr); + // delete this; TRT or the creator of the plugin will delete this plugin object } catch (std::exception const& e) { caughtError(e); } - return nullptr; } -DimsExprs EmbLayerNormPluginDynamic::getOutputDimensions( - int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept +////// +// IPluginV3 method definitions: +// - getCapabilityInterface() (Base) +// - clone() (HFace, MTron) +////// +IPluginCapability* EmbLayerNormPluginDynamic::getCapabilityInterface(PluginCapabilityType type) noexcept { try { - // Input should be input ids and token ids and the input mask - // Output should be the embeddings tensor and mask indices - PLUGIN_ASSERT(nbInputs == 3); - - PLUGIN_ASSERT(inputs[0].nbDims == 2); // BxS - PLUGIN_ASSERT(inputs[0].nbDims == inputs[1].nbDims); - PLUGIN_ASSERT(inputs[0].nbDims == inputs[2].nbDims); - - PLUGIN_ASSERT(outputIndex == 0 || outputIndex == 1); - - if (outputIndex == 0) + if (type == PluginCapabilityType::kBUILD) { - DimsExprs ret; - ret.nbDims = 5; - ret.d[0] = inputs[0].d[0]; - ret.d[1] = inputs[0].d[1]; - ret.d[2] = exprBuilder.constant(mLd); - ret.d[3] = exprBuilder.constant(1); - ret.d[4] = exprBuilder.constant(1); - return ret; - } - - DimsExprs ret; - ret.nbDims = 2; - ret.d[0] = inputs[0].d[BDIM]; - auto cms0 = exprBuilder.constant(unfusedMaskSize); - - // this code must match getMHAMaskPackedSize in bertCommon.h - bool const isSmOK - = (mSM == kSM_75 || mSM == kSM_80 || mSM == kSM_86 || mSM == kSM_87 || mSM == kSM_89 || mSM == kSM_90); - bool const isPrecisionOK = (mMhaType == nvinfer1::DataType::kHALF || mMhaType == nvinfer1::DataType::kINT8); - if (mUseFullMask || (isSmOK && isPrecisionOK)) - { - // support 128, 384 in both int8 and fp16 - auto cms128 = exprBuilder.constant(packedMaskSize128); - auto cms384 = exprBuilder.constant(packedMaskSize384); - auto c128 = exprBuilder.constant(128); - auto c384 = exprBuilder.constant(384); - auto is128 = exprBuilder.operation(DimensionOperation::kEQUAL, *inputs[0].d[SDIM], *c128); - auto is384 = exprBuilder.operation(DimensionOperation::kEQUAL, *inputs[0].d[SDIM], *c384); - auto sel128 = exprBuilder.operation(DimensionOperation::kPROD, *is128, *cms128); - auto sel384 = exprBuilder.operation(DimensionOperation::kPROD, *is384, *cms384); - auto maskSize = exprBuilder.operation(DimensionOperation::kSUM, *sel384, *sel128); - - // support 64, 96 in both int8 and fp16 - auto cms64 = exprBuilder.constant(packedMaskSize64); - auto cms96 = exprBuilder.constant(packedMaskSize96); - auto c64 = exprBuilder.constant(64); - auto c96 = exprBuilder.constant(96); - - auto is64 = exprBuilder.operation(DimensionOperation::kEQUAL, *inputs[0].d[SDIM], *c64); - auto is96 = exprBuilder.operation(DimensionOperation::kEQUAL, *inputs[0].d[SDIM], *c96); - auto sel64 = exprBuilder.operation(DimensionOperation::kPROD, *is64, *cms64); - auto sel96 = exprBuilder.operation(DimensionOperation::kPROD, *is96, *cms96); - auto maskSize2 = exprBuilder.operation(DimensionOperation::kSUM, *sel64, *sel96); - maskSize = exprBuilder.operation(DimensionOperation::kSUM, *maskSize, *maskSize2); - - auto is0 = exprBuilder.operation(DimensionOperation::kEQUAL, *maskSize, *exprBuilder.constant(0)); - auto sel0 = exprBuilder.operation(DimensionOperation::kPROD, *is0, *cms0); - auto combinedMaskSize = exprBuilder.operation(DimensionOperation::kSUM, *maskSize, *sel0); - ret.d[1] = combinedMaskSize; + return static_cast(this); } - else + if (type == PluginCapabilityType::kRUNTIME) { - ret.d[1] = cms0; + return static_cast(this); } - - return ret; + PLUGIN_ASSERT(type == PluginCapabilityType::kCORE); + return static_cast(this); } catch (std::exception const& e) { caughtError(e); } - return DimsExprs{}; + return nullptr; } -bool EmbLayerNormPluginDynamic::supportsFormatCombination( - int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept +IPluginV3* EmbLayerNormPluginDynamic::clone() noexcept { - // 3 inputs of size BxS - PLUGIN_ASSERT(nbInputs == 3); - PLUGIN_ASSERT(nbOutputs == 2); - - PluginTensorDesc const& desc = inOut[pos]; - if (desc.format != TensorFormat::kLINEAR) + try { - return false; + BERT_DEBUG_MSG("EmbLayerNormPluginDynamic clone."); + + auto p = new EmbLayerNormPluginDynamic( + mLayerName, mType, mMhaType, mBeta, mGamma, mWordEmb, mPosEmb, mTokEmb, mUseFullMask == 1); + p->mS = mS; + p->setPluginNamespace(mNamespace.c_str()); + + return p; } - if (pos == 0) + catch (std::exception const& e) { - return desc.type == DataType::kINT32 && desc.dims.nbDims == 2; + caughtError(e); } + return nullptr; +} - PluginTensorDesc const& prev = inOut[pos - 1]; - if (pos == 1 || pos == 2) +// End IPluginV3 method definitions + +////// +// IPluginV3OneRuntime method definitions: +// - getFieldsToSerialize() +// - onShapeChange() +// - attachToContext() +// - enqueue() +///// + +PluginFieldCollection const* EmbLayerNormPluginDynamic::getFieldsToSerialize() noexcept +{ + mDataToSerialize.clear(); + mDataToSerialize.emplace_back("output_fp16", &mOutputFp16, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("full_mask", &mUseFullMask, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("mha_type_id", &mMhaType, PluginFieldType::kINT32, 1); + mDataToSerialize.emplace_back("bert_embeddings_layernorm_beta", static_cast(mBeta.values), + PluginFieldType::kFLOAT32, mBeta.count); + mDataToSerialize.emplace_back("bert_embeddings_layernorm_gamma", static_cast(mGamma.values), + PluginFieldType::kFLOAT32, mGamma.count); + if (mOutputFp16) { - return desc.type == DataType::kINT32 && desc.dims.nbDims == 2 && desc.dims.d[BDIM] == prev.dims.d[BDIM] - && desc.dims.d[SDIM] == prev.dims.d[SDIM]; + mDataToSerialize.emplace_back("bert_embeddings_word_embeddings", static_cast(mWordEmb.values), + PluginFieldType::kFLOAT16, mWordEmb.count); + mDataToSerialize.emplace_back("bert_embeddings_token_type_embeddings", static_cast(mTokEmb.values), + PluginFieldType::kFLOAT16, mTokEmb.count); + mDataToSerialize.emplace_back("bert_embeddings_position_embeddings", static_cast(mPosEmb.values), + PluginFieldType::kFLOAT16, mPosEmb.count); } - - // embedded sequence - if (pos == 3) + else { - return desc.type == mType && desc.dims.nbDims == 5 && desc.dims.d[BDIM] == prev.dims.d[BDIM] - && desc.dims.d[SDIM] == prev.dims.d[SDIM] && desc.dims.d[3] == 1 && desc.dims.d[4] == 1; + mDataToSerialize.emplace_back("bert_embeddings_word_embeddings", static_cast(mWordEmb.values), + PluginFieldType::kFLOAT32, mWordEmb.count); + mDataToSerialize.emplace_back("bert_embeddings_token_type_embeddings", + static_cast(mTokEmb.values), PluginFieldType::kFLOAT32, mTokEmb.count); + mDataToSerialize.emplace_back("bert_embeddings_position_embeddings", static_cast(mPosEmb.values), + PluginFieldType::kFLOAT32, mPosEmb.count); } - // mask - return desc.type == DataType::kINT32; + mFCToSerialize.nbFields = mDataToSerialize.size(); + mFCToSerialize.fields = mDataToSerialize.data(); + return &mFCToSerialize; } -void EmbLayerNormPluginDynamic::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs, - DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept +int32_t EmbLayerNormPluginDynamic::onShapeChange( + PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) noexcept { BERT_DEBUG_MSG("EmbLayerNormPluginDynamic configurePlugin."); + try + { + // Validate input arguments + PLUGIN_ASSERT(nbOutputs == 2); + PLUGIN_ASSERT(nbInputs == 3); - // Validate input arguments - PLUGIN_ASSERT(nbOutputs == 2); - PLUGIN_ASSERT(nbInputs == 3); + PLUGIN_ASSERT(inputs[0].dims.nbDims == 2); + int32_t const S = inputs[0].dims.d[SDIM]; + mS = S; + int32_t const B = inputs[0].dims.d[BDIM]; + TRT_UNUSED B; + PLUGIN_ASSERT(mS == static_cast(inputs[1].dims.d[SDIM])); + PLUGIN_ASSERT(B == inputs[1].dims.d[BDIM]); + PLUGIN_ASSERT(mS == static_cast(inputs[2].dims.d[SDIM])); + PLUGIN_ASSERT(B == inputs[2].dims.d[BDIM]); + + PLUGIN_ASSERT(outputs[0].dims.nbDims == 5); + PLUGIN_ASSERT(static_cast(outputs[0].dims.d[SDIM]) == mS); + PLUGIN_ASSERT(outputs[0].dims.d[BDIM] == B); + PLUGIN_ASSERT(static_cast(outputs[0].dims.d[2]) == mLd); + PLUGIN_ASSERT(outputs[0].dims.d[3] == 1); + PLUGIN_ASSERT(outputs[0].dims.d[4] == 1); + + if (mUseFullMask) + { + // user force full_mask + PLUGIN_ASSERT(outputs[1].dims.nbDims == 2); + PLUGIN_ASSERT(outputs[1].dims.d[0] == B); + PLUGIN_ASSERT((outputs[1].dims.d[1] == -1) || (outputs[1].dims.d[1] == packedMaskSize384) + || (outputs[1].dims.d[1] == packedMaskSize128)); + } + else + { + // auto detect using mhatype + if (S != -1 && B != -1) + { + PLUGIN_ASSERT(outputs[1].dims.nbDims == 2); + PLUGIN_ASSERT(outputs[1].dims.d[0] == B); + int32_t packedSize = getMHAMaskPackedSize(mSM, mMhaType, S); + TRT_UNUSED packedSize; + PLUGIN_ASSERT(outputs[1].dims.d[1] == -1 || outputs[1].dims.d[1] == packedSize); + } + } - PLUGIN_ASSERT(inputs[0].desc.dims.nbDims == 2); - int32_t const S = inputs[0].desc.dims.d[SDIM]; - mS = S; - int32_t const B = inputs[0].desc.dims.d[BDIM]; - TRT_UNUSED B; - PLUGIN_ASSERT(mS == static_cast(inputs[1].desc.dims.d[SDIM])); - PLUGIN_ASSERT(B == inputs[1].desc.dims.d[BDIM]); - PLUGIN_ASSERT(mS == static_cast(inputs[2].desc.dims.d[SDIM])); - PLUGIN_ASSERT(B == inputs[2].desc.dims.d[BDIM]); - - PLUGIN_ASSERT(outputs[0].desc.dims.nbDims == 5); - PLUGIN_ASSERT(static_cast(outputs[0].desc.dims.d[SDIM]) == mS); - PLUGIN_ASSERT(outputs[0].desc.dims.d[BDIM] == B); - PLUGIN_ASSERT(static_cast(outputs[0].desc.dims.d[2]) == mLd); - PLUGIN_ASSERT(outputs[0].desc.dims.d[3] == 1); - PLUGIN_ASSERT(outputs[0].desc.dims.d[4] == 1); - - if (mUseFullMask) - { - // user force full_mask - PLUGIN_ASSERT(outputs[1].desc.dims.nbDims == 2); - PLUGIN_ASSERT(outputs[1].desc.dims.d[0] == B); - PLUGIN_ASSERT((outputs[1].desc.dims.d[1] == -1) || (outputs[1].desc.dims.d[1] == packedMaskSize384) - || (outputs[1].desc.dims.d[1] == packedMaskSize128)); + PLUGIN_ASSERT(inputs[0].type == DataType::kINT32); + PLUGIN_ASSERT(inputs[1].type == DataType::kINT32); + PLUGIN_ASSERT(inputs[2].type == DataType::kINT32); + PLUGIN_ASSERT(outputs[0].type == mType); + PLUGIN_ASSERT(outputs[1].type == DataType::kINT32); + return pluginStatus_t::STATUS_SUCCESS; } - else + catch (std::exception const& e) { - // auto detect using mhatype - if (S != -1 && B != -1) - { - PLUGIN_ASSERT(outputs[1].desc.dims.nbDims == 2); - PLUGIN_ASSERT(outputs[1].desc.dims.d[0] == B); - int32_t packedSize = getMHAMaskPackedSize(mSM, mMhaType, S); - TRT_UNUSED packedSize; - PLUGIN_ASSERT(outputs[1].desc.dims.d[1] == -1 || outputs[1].desc.dims.d[1] == packedSize); - } + caughtError(e); } - - PLUGIN_ASSERT(inputs[0].desc.type == DataType::kINT32); - PLUGIN_ASSERT(inputs[1].desc.type == DataType::kINT32); - PLUGIN_ASSERT(inputs[2].desc.type == DataType::kINT32); - PLUGIN_ASSERT(outputs[0].desc.type == mType); - PLUGIN_ASSERT(outputs[1].desc.type == DataType::kINT32); + return pluginStatus_t::STATUS_FAILURE; } -size_t EmbLayerNormPluginDynamic::getWorkspaceSize( - PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept +IPluginV3* EmbLayerNormPluginDynamic::attachToContext(IPluginResourceContext* context) noexcept { - return 0; + return clone(); } int32_t EmbLayerNormPluginDynamic::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* /* outputDesc */, @@ -394,92 +338,185 @@ int32_t EmbLayerNormPluginDynamic::enqueue(PluginTensorDesc const* inputDesc, Pl return STATUS_FAILURE; } -// IPluginV2Ext Methods -DataType EmbLayerNormPluginDynamic::getOutputDataType( - int32_t index, DataType const* inputTypes, int32_t nbInputs) const noexcept -{ +// end IPluginV3OneRuntime method definitions - PLUGIN_ASSERT(index == 0 || index == 1); - if (index == 0) - { - PLUGIN_ASSERT(mType == DataType::kHALF || mType == DataType::kFLOAT); - return mType; - } - return DataType::kINT32; -} +/////// +// IPluginV3OneBuild method definitions +// - getNbOutputs() +// - supportsFormatCombination() +// - getOutputShapes +// - getOutputDataTypes() +// - configurePlugin() +// - getWorkSpaceSize() +////// -// IPluginV2 Methods -char const* EmbLayerNormPluginDynamic::getPluginType() const noexcept +int32_t EmbLayerNormPluginDynamic::getNbOutputs() const noexcept { - return EMB_LAYER_NORM_NAME; + return 2; } -char const* EmbLayerNormPluginDynamic::getPluginVersion() const noexcept +bool EmbLayerNormPluginDynamic::supportsFormatCombination( + int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept { - return EMB_LAYER_NORM_VERSION; + // 3 inputs of size BxS + PLUGIN_ASSERT(nbInputs == 3); + PLUGIN_ASSERT(nbOutputs == 2); + + PluginTensorDesc const& desc = inOut[pos].desc; + if (desc.format != TensorFormat::kLINEAR) + { + return false; + } + if (pos == 0) + { + return desc.type == DataType::kINT32 && desc.dims.nbDims == 2; + } + + PluginTensorDesc const& prev = inOut[pos - 1].desc; + if (pos == 1 || pos == 2) + { + return desc.type == DataType::kINT32 && desc.dims.nbDims == 2 && desc.dims.d[BDIM] == prev.dims.d[BDIM] + && desc.dims.d[SDIM] == prev.dims.d[SDIM]; + } + + // embedded sequence + if (pos == 3) + { + return desc.type == mType && desc.dims.nbDims == 5 && desc.dims.d[BDIM] == prev.dims.d[BDIM] + && desc.dims.d[SDIM] == prev.dims.d[SDIM] && desc.dims.d[3] == 1 && desc.dims.d[4] == 1; + } + // mask + return desc.type == DataType::kINT32; } -int32_t EmbLayerNormPluginDynamic::getNbOutputs() const noexcept +int32_t EmbLayerNormPluginDynamic::getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, + DimsExprs const* shapeInputs, int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, + IExprBuilder& exprBuilder) noexcept { - return 2; + try + { + // Input should be input ids and token ids and the input mask + // Output should be the embeddings tensor and mask indices + PLUGIN_ASSERT(nbInputs == 3); + PLUGIN_ASSERT(inputs != nullptr); + PLUGIN_ASSERT(inputs[0].nbDims == 2); // BxS + PLUGIN_ASSERT(inputs[0].nbDims == inputs[1].nbDims); + PLUGIN_ASSERT(inputs[0].nbDims == inputs[2].nbDims); + + PLUGIN_ASSERT(nbOutputs == 2); + PLUGIN_ASSERT(outputs != nullptr); + + // output 0: embeddings tensor + outputs[0].nbDims = 5; + outputs[0].d[0] = inputs[0].d[0]; + outputs[0].d[1] = inputs[0].d[1]; + outputs[0].d[2] = exprBuilder.constant(mLd); + outputs[0].d[3] = exprBuilder.constant(1); + outputs[0].d[4] = exprBuilder.constant(1); + + // output 1: mask indices + outputs[1].nbDims = 2; + outputs[1].d[0] = inputs[0].d[BDIM]; + auto cms0 = exprBuilder.constant(unfusedMaskSize); + + // this code must match getMHAMaskPackedSize in bertCommon.h + bool const isSmOK + = (mSM == kSM_75 || mSM == kSM_80 || mSM == kSM_86 || mSM == kSM_87 || mSM == kSM_89 || mSM == kSM_90); + bool const isPrecisionOK = (mMhaType == nvinfer1::DataType::kHALF || mMhaType == nvinfer1::DataType::kINT8); + if (mUseFullMask || (isSmOK && isPrecisionOK)) + { + // support 128, 384 in both int8 and fp16 + auto cms128 = exprBuilder.constant(packedMaskSize128); + auto cms384 = exprBuilder.constant(packedMaskSize384); + auto c128 = exprBuilder.constant(128); + auto c384 = exprBuilder.constant(384); + auto is128 = exprBuilder.operation(DimensionOperation::kEQUAL, *inputs[0].d[SDIM], *c128); + auto is384 = exprBuilder.operation(DimensionOperation::kEQUAL, *inputs[0].d[SDIM], *c384); + auto sel128 = exprBuilder.operation(DimensionOperation::kPROD, *is128, *cms128); + auto sel384 = exprBuilder.operation(DimensionOperation::kPROD, *is384, *cms384); + auto maskSize = exprBuilder.operation(DimensionOperation::kSUM, *sel384, *sel128); + + // support 64, 96 in both int8 and fp16 + auto cms64 = exprBuilder.constant(packedMaskSize64); + auto cms96 = exprBuilder.constant(packedMaskSize96); + auto c64 = exprBuilder.constant(64); + auto c96 = exprBuilder.constant(96); + + auto is64 = exprBuilder.operation(DimensionOperation::kEQUAL, *inputs[0].d[SDIM], *c64); + auto is96 = exprBuilder.operation(DimensionOperation::kEQUAL, *inputs[0].d[SDIM], *c96); + auto sel64 = exprBuilder.operation(DimensionOperation::kPROD, *is64, *cms64); + auto sel96 = exprBuilder.operation(DimensionOperation::kPROD, *is96, *cms96); + auto maskSize2 = exprBuilder.operation(DimensionOperation::kSUM, *sel64, *sel96); + maskSize = exprBuilder.operation(DimensionOperation::kSUM, *maskSize, *maskSize2); + + auto is0 = exprBuilder.operation(DimensionOperation::kEQUAL, *maskSize, *exprBuilder.constant(0)); + auto sel0 = exprBuilder.operation(DimensionOperation::kPROD, *is0, *cms0); + auto combinedMaskSize = exprBuilder.operation(DimensionOperation::kSUM, *maskSize, *sel0); + outputs[1].d[1] = combinedMaskSize; + } + else + { + outputs[1].d[1] = cms0; + } + return pluginStatus_t::STATUS_SUCCESS; + } + catch (std::exception const& e) + { + caughtError(e); + } + return pluginStatus_t::STATUS_FAILURE; } -int32_t EmbLayerNormPluginDynamic::initialize() noexcept +int32_t EmbLayerNormPluginDynamic::getOutputDataTypes( + DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept { - return 0; + try + { + PLUGIN_ASSERT(outputTypes != nullptr); + PLUGIN_ASSERT(nbOutputs == 2); + PLUGIN_ASSERT(inputTypes != nullptr); + PLUGIN_ASSERT(nbInputs == 3); + PLUGIN_ASSERT(mType == DataType::kHALF || mType == DataType::kFLOAT); + outputTypes[0] = mType; + outputTypes[1] = DataType::kINT32; + return pluginStatus_t::STATUS_SUCCESS; + } + catch (std::exception const& e) + { + caughtError(e); + } + return pluginStatus_t::STATUS_FAILURE; } -void EmbLayerNormPluginDynamic::terminate() noexcept +int32_t EmbLayerNormPluginDynamic::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs, + DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept { - BERT_DEBUG_MSG("EmbLayerNormPluginDynamic terminate."); + return pluginStatus_t::STATUS_SUCCESS; } -size_t EmbLayerNormPluginDynamic::getSerializationSize() const noexcept +size_t EmbLayerNormPluginDynamic::getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t nbInputs, + DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept { - size_t const wordSize = getElementSize(mType); - return sizeof(mType) // type - + sizeof(mMhaType) // mha plugin datatype - + sizeof(mLd) * 5 // mLd, mS, m*VocabSize - + sizeof(mUseFullMask) // mask type - + sizeof(mSM) // smversion - + 2 * sizeof(float) * mLd // beta + gamma - + wordSize * mLd * mWordVocabSize // word emb - + wordSize * mLd * mPosVocabSize // pos emb - + wordSize * mLd * mTokVocabSize // tok emb - ; + return 0; } -void EmbLayerNormPluginDynamic::serialize(void* buffer) const noexcept +// End IPluginV3OneBuild method definitions + +////// +// IPluginV3OneCore method definitions +// - getPluginVersion() +// - getPluginName() +// - getPluginNamespace() +// - setPluginNamespace() +////// +char const* EmbLayerNormPluginDynamic::getPluginVersion() const noexcept { - serialize_value(&buffer, mType); - serialize_value(&buffer, mMhaType); - serialize_value(&buffer, mLd); - serialize_value(&buffer, mS); - serialize_value(&buffer, mWordVocabSize); - serialize_value(&buffer, mPosVocabSize); - serialize_value(&buffer, mTokVocabSize); - serialize_value(&buffer, mUseFullMask); - serialize_value(&buffer, mSM); - - char* d = static_cast(buffer); - serFromDev(d, mBetaDev.get(), mLd); - serFromDev(d, mGammaDev.get(), mLd); - size_t const wordSize = getElementSize(mType); - serFromDev(d, static_cast(mWordEmbDev.get()), mLd * mWordVocabSize * wordSize); - serFromDev(d, static_cast(mPosEmbDev.get()), mLd * mPosVocabSize * wordSize); - serFromDev(d, static_cast(mTokEmbDev.get()), mLd * mTokVocabSize * wordSize); + return gEmbLayerNormVersion; } -void EmbLayerNormPluginDynamic::destroy() noexcept +char const* EmbLayerNormPluginDynamic::getPluginName() const noexcept { - BERT_DEBUG_MSG("EmbLayerNormPluginDynamic destroy."); - // This gets called when the network containing plugin is destroyed - mGammaDev.reset(nullptr); - mBetaDev.reset(nullptr); - mWordEmbDev.reset(nullptr); - mPosEmbDev.reset(nullptr); - mTokEmbDev.reset(nullptr); - delete this; + return gEmbLayerNormName; } void EmbLayerNormPluginDynamic::setPluginNamespace(char const* libNamespace) noexcept @@ -499,10 +536,14 @@ char const* EmbLayerNormPluginDynamic::getPluginNamespace() const noexcept return mNamespace.c_str(); } -/////////////////////// +// End IPluginV3OneCore method definitions + +//////////////////////////// Plugin Creator member definitions ///////////////////////////// EmbLayerNormPluginDynamicCreator::EmbLayerNormPluginDynamicCreator() { + static std::mutex sMutex; + std::lock_guard lock(sMutex); mPluginAttributes.clear(); mPluginAttributes.emplace_back(PluginField("bert_embeddings_layernorm_beta")); mPluginAttributes.emplace_back(PluginField("bert_embeddings_layernorm_gamma")); @@ -518,12 +559,12 @@ EmbLayerNormPluginDynamicCreator::EmbLayerNormPluginDynamicCreator() char const* EmbLayerNormPluginDynamicCreator::getPluginName() const noexcept { - return EMB_LAYER_NORM_NAME; + return gEmbLayerNormName; } char const* EmbLayerNormPluginDynamicCreator::getPluginVersion() const noexcept { - return EMB_LAYER_NORM_VERSION; + return gEmbLayerNormVersion; } PluginFieldCollection const* EmbLayerNormPluginDynamicCreator::getFieldNames() noexcept @@ -531,7 +572,8 @@ PluginFieldCollection const* EmbLayerNormPluginDynamicCreator::getFieldNames() n return &mFC; } -IPluginV2* EmbLayerNormPluginDynamicCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept +IPluginV3* EmbLayerNormPluginDynamicCreator::createPlugin( + char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept { try { @@ -630,22 +672,6 @@ IPluginV2* EmbLayerNormPluginDynamicCreator::createPlugin(char const* name, Plug return nullptr; } -IPluginV2* EmbLayerNormPluginDynamicCreator::deserializePlugin( - char const* name, void const* serialData, size_t serialLength) noexcept -{ - try - { - // This object will be deleted when the network is destroyed, which will - // call EmbLayerNormPluginDynamic::destroy() - return new EmbLayerNormPluginDynamic(name, serialData, serialLength); - } - catch (std::exception const& e) - { - caughtError(e); - } - return nullptr; -} - void EmbLayerNormPluginDynamicCreator::setPluginNamespace(char const* libNamespace) noexcept { try diff --git a/plugin/embLayerNormPlugin/embLayerNormPlugin.h b/plugin/embLayerNormPlugin/embLayerNormPlugin.h index 5eb40958a..06c2bb11b 100644 --- a/plugin/embLayerNormPlugin/embLayerNormPlugin.h +++ b/plugin/embLayerNormPlugin/embLayerNormPlugin.h @@ -45,52 +45,76 @@ int32_t embSkipLayerNorm(cudaStream_t stream, int32_t ld, int32_t B, int32_t S, cudaError_t convertMask(uint32_t const S, uint32_t const B, uint32_t const warps_m, uint32_t const warps_n, uint32_t const warps_k, int32_t const* inputMaskSB, uint32_t* inputMaskX, cudaStream_t stream); -class EmbLayerNormPluginDynamic : public nvinfer1::IPluginV2DynamicExt +class EmbLayerNormPluginDynamic : public IPluginV3, + public IPluginV3OneCore, + public IPluginV3OneBuild, + public IPluginV3OneRuntime { public: EmbLayerNormPluginDynamic(std::string const& name, nvinfer1::DataType const type, nvinfer1::DataType const mhaType, nvinfer1::Weights const& beta, nvinfer1::Weights const& gamma, nvinfer1::Weights const& word_emb, nvinfer1::Weights const& pos_emb, nvinfer1::Weights const& tok_emb, bool const useFullMask); - EmbLayerNormPluginDynamic(std::string const& name, void const* data, size_t length); - // It doesn't make sense to make EmbLayerNormPluginDynamic without arguments, so we // delete default constructor. EmbLayerNormPluginDynamic() = delete; - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs, - nvinfer1::IExprBuilder& exprBuilder) noexcept override; + ~EmbLayerNormPluginDynamic() override; + + // IPluginV3 Methods + // NOTE: since this is itself is an abstract class, the rest of virtual methods defined in its children classes + IPluginCapability* getCapabilityInterface(PluginCapabilityType type) noexcept override; + // end of IPluginV3 Methods + + // IPluginV3OneCore Methods + char const* getPluginName() const noexcept override; + + char const* getPluginNamespace() const noexcept override; + + void setPluginNamespace(char const* pluginNamespace) noexcept; + + char const* getPluginVersion() const noexcept override; + // end of IPluginV3OneCore Methods + + // IPluginV3Build Methods bool supportsFormatCombination( - int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override; - void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs, - nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override; - size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, - nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override; + + int32_t configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, + int32_t nbOutputs) noexcept override; + + size_t getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t nbInputs, + DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + + int32_t getOutputDataTypes( + DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept override; + + int32_t getNbOutputs() const noexcept override; + + int32_t getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs, + int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept override; + // end IPluginV3Build Methods + + // IPluginV3Runtime Methods + IPluginV3* clone() noexcept; + + int32_t onShapeChange( + PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out, int32_t nbOutputs) noexcept override; + int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType( - int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override; + IPluginV3* attachToContext(IPluginResourceContext* context) noexcept override; - // IPluginV2 Methods - char const* getPluginType() const noexcept override; - char const* getPluginVersion() const noexcept override; - int32_t getNbOutputs() const noexcept override; - int32_t initialize() noexcept override; - void terminate() noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void* buffer) const noexcept override; - void destroy() noexcept override; - void setPluginNamespace(char const* pluginNamespace) noexcept override; - char const* getPluginNamespace() const noexcept override; + PluginFieldCollection const* getFieldsToSerialize() noexcept override; + // end IPluginV3Runtime Methods private: + // metadata fields std::string const mLayerName; std::string mNamespace; + // device-side bert::cuda_unique_ptr mGammaDev; bert::cuda_unique_ptr mBetaDev; bert::cuda_unique_ptr mWordEmbDev; @@ -101,26 +125,29 @@ class EmbLayerNormPluginDynamic : public nvinfer1::IPluginV2DynamicExt size_t mWordVocabSize; size_t mPosVocabSize; size_t mTokVocabSize; + + // members that partcipate in ser/deserialization bert::WeightsWithOwnership mBeta; bert::WeightsWithOwnership mGamma; bert::WeightsWithOwnership mWordEmb; bert::WeightsWithOwnership mTokEmb; bert::WeightsWithOwnership mPosEmb; nvinfer1::DataType mType; - bool mUseFullMask; + int32_t mOutputFp16; + int32_t mUseFullMask; nvinfer1::DataType mMhaType; int32_t mSM; - using IPluginV2::getOutputDimensions; - using IPluginV2::getWorkspaceSize; - using IPluginV2::enqueue; - using IPluginV2Ext::configurePlugin; + // IPluginV3 serialization related + std::vector mDataToSerialize; + nvinfer1::PluginFieldCollection mFCToSerialize; }; -class EmbLayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator +class EmbLayerNormPluginDynamicCreator : public nvinfer1::IPluginCreatorV3One { public: EmbLayerNormPluginDynamicCreator(); + ~EmbLayerNormPluginDynamicCreator() override = default; char const* getPluginName() const noexcept override; @@ -128,12 +155,9 @@ class EmbLayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; - - nvinfer1::IPluginV2* deserializePlugin( - char const* name, void const* serialData, size_t serialLength) noexcept override; + IPluginV3* createPlugin(char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept override; - void setPluginNamespace(char const* pluginNamespace) noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept; char const* getPluginNamespace() const noexcept override; diff --git a/plugin/embLayerNormPlugin/embLayerNormPluginLegacy.cpp b/plugin/embLayerNormPlugin/embLayerNormPluginLegacy.cpp new file mode 100644 index 000000000..6c037189e --- /dev/null +++ b/plugin/embLayerNormPlugin/embLayerNormPluginLegacy.cpp @@ -0,0 +1,669 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#if CUDA_VERSION >= 10010 + +#include +#include +#include + +#include "NvInfer.h" +#include "common/serialize.hpp" +#include "embLayerNormPluginLegacy.h" + +using namespace nvinfer1; +using namespace nvinfer1::plugin; +using namespace nvinfer1::plugin::bert; + +namespace +{ +char const* gEmbLayerNormVersion{"1"}; +char const* gEmbLayerNormName{"CustomEmbLayerNormPluginDynamic"}; +} // namespace + +// Static class fields initialization +PluginFieldCollection EmbLayerNormPluginDynamicLegacyCreator::mFC{}; +std::vector EmbLayerNormPluginDynamicLegacyCreator::mPluginAttributes; + +REGISTER_TENSORRT_PLUGIN(EmbLayerNormPluginDynamicLegacyCreator); + +EmbLayerNormPluginDynamicLegacy::EmbLayerNormPluginDynamicLegacy(std::string const& name, DataType const type, + DataType const mhaType, Weights const& beta, Weights const& gamma, Weights const& wordEmb, Weights const& posEmb, + Weights const& tokEmb, bool const useFullMask) + : mLayerName(name) + , mLd(beta.count) + , mType(type) + , mUseFullMask(useFullMask) + , mMhaType(mhaType) +{ + // Assuming Weights.count is the number of elements and not bytes + PLUGIN_VALIDATE(beta.count == gamma.count); + PLUGIN_VALIDATE(mLd > 0U); + PLUGIN_VALIDATE(wordEmb.count % mLd == 0); + PLUGIN_VALIDATE(posEmb.count % mLd == 0); + PLUGIN_VALIDATE(tokEmb.count % mLd == 0); + mWordVocabSize = wordEmb.count / mLd; + mPosVocabSize = posEmb.count / mLd; + mTokVocabSize = tokEmb.count / mLd; + mSM = getSMVersion(); + // mS is set during configure + + mBeta.convertAndCopy(beta, nvinfer1::DataType::kFLOAT); + mGamma.convertAndCopy(gamma, nvinfer1::DataType::kFLOAT); + mWordEmb.convertAndCopy(wordEmb, mType); + mTokEmb.convertAndCopy(tokEmb, mType); + mPosEmb.convertAndCopy(posEmb, mType); + + copyToDevice(mGamma, sizeof(float) * mGamma.count, mGammaDev); + copyToDevice(mBeta, sizeof(float) * mBeta.count, mBetaDev); + copyToDevice(mWordEmb, getWeightsSize(mWordEmb, mType), mWordEmbDev); + copyToDevice(mPosEmb, getWeightsSize(mPosEmb, mType), mPosEmbDev); + copyToDevice(mTokEmb, getWeightsSize(mTokEmb, mType), mTokEmbDev); +} + +EmbLayerNormPluginDynamicLegacy::EmbLayerNormPluginDynamicLegacy( + std::string const& name, void const* data, size_t length) + : mLayerName(name) + , mGammaDev(nullptr) + , mBetaDev(nullptr) + , mWordEmbDev(nullptr) + , mTokEmbDev(nullptr) + , mPosEmbDev(nullptr) +{ + BERT_DEBUG_MSG("EmbLayerNormPluginDynamicLegacy deserialize."); + + // Deserialize in the same order as serialization + deserialize_value(&data, &length, &mType); + deserialize_value(&data, &length, &mMhaType); + deserialize_value(&data, &length, &mLd); + deserialize_value(&data, &length, &mS); + deserialize_value(&data, &length, &mWordVocabSize); + deserialize_value(&data, &length, &mPosVocabSize); + deserialize_value(&data, &length, &mTokVocabSize); + deserialize_value(&data, &length, &mUseFullMask); + deserialize_value(&data, &length, &mSM); + + char const* d = static_cast(data); + mBeta.convertAndCopy(d, mLd, nvinfer1::DataType::kFLOAT); + mGamma.convertAndCopy(d, mLd, nvinfer1::DataType::kFLOAT); + mWordEmb.convertAndCopy(d, mLd * mWordVocabSize, mType); + mPosEmb.convertAndCopy(d, mLd * mPosVocabSize, mType); + mTokEmb.convertAndCopy(d, mLd * mTokVocabSize, mType); + + copyToDevice(mGamma, sizeof(float) * mGamma.count, mGammaDev); + copyToDevice(mBeta, sizeof(float) * mBeta.count, mBetaDev); + copyToDevice(mWordEmb, getWeightsSize(mWordEmb, mType), mWordEmbDev); + copyToDevice(mPosEmb, getWeightsSize(mPosEmb, mType), mPosEmbDev); + copyToDevice(mTokEmb, getWeightsSize(mTokEmb, mType), mTokEmbDev); +} + +// IPluginV2DynamicExt Methods +IPluginV2DynamicExt* EmbLayerNormPluginDynamicLegacy::clone() const noexcept +{ + try + { + BERT_DEBUG_MSG("EmbLayerNormPluginDynamicLegacy clone."); + + auto p = new EmbLayerNormPluginDynamicLegacy( + mLayerName, mType, mMhaType, mBeta, mGamma, mWordEmb, mPosEmb, mTokEmb, mUseFullMask); + p->mS = mS; + p->setPluginNamespace(mNamespace.c_str()); + + return p; + } + catch (std::exception const& e) + { + caughtError(e); + } + return nullptr; +} + +DimsExprs EmbLayerNormPluginDynamicLegacy::getOutputDimensions( + int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept +{ + try + { + // Input should be input ids and token ids and the input mask + // Output should be the embeddings tensor and mask indices + PLUGIN_ASSERT(nbInputs == 3); + + PLUGIN_ASSERT(inputs[0].nbDims == 2); // BxS + PLUGIN_ASSERT(inputs[0].nbDims == inputs[1].nbDims); + PLUGIN_ASSERT(inputs[0].nbDims == inputs[2].nbDims); + + PLUGIN_ASSERT(outputIndex == 0 || outputIndex == 1); + + if (outputIndex == 0) + { + DimsExprs ret; + ret.nbDims = 5; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[0].d[1]; + ret.d[2] = exprBuilder.constant(mLd); + ret.d[3] = exprBuilder.constant(1); + ret.d[4] = exprBuilder.constant(1); + return ret; + } + + DimsExprs ret; + ret.nbDims = 2; + ret.d[0] = inputs[0].d[BDIM]; + auto cms0 = exprBuilder.constant(unfusedMaskSize); + + // this code must match getMHAMaskPackedSize in bertCommon.h + bool const isSmOK + = (mSM == kSM_75 || mSM == kSM_80 || mSM == kSM_86 || mSM == kSM_87 || mSM == kSM_89 || mSM == kSM_90); + bool const isPrecisionOK = (mMhaType == nvinfer1::DataType::kHALF || mMhaType == nvinfer1::DataType::kINT8); + if (mUseFullMask || (isSmOK && isPrecisionOK)) + { + // support 128, 384 in both int8 and fp16 + auto cms128 = exprBuilder.constant(packedMaskSize128); + auto cms384 = exprBuilder.constant(packedMaskSize384); + auto c128 = exprBuilder.constant(128); + auto c384 = exprBuilder.constant(384); + auto is128 = exprBuilder.operation(DimensionOperation::kEQUAL, *inputs[0].d[SDIM], *c128); + auto is384 = exprBuilder.operation(DimensionOperation::kEQUAL, *inputs[0].d[SDIM], *c384); + auto sel128 = exprBuilder.operation(DimensionOperation::kPROD, *is128, *cms128); + auto sel384 = exprBuilder.operation(DimensionOperation::kPROD, *is384, *cms384); + auto maskSize = exprBuilder.operation(DimensionOperation::kSUM, *sel384, *sel128); + + // support 64, 96 in both int8 and fp16 + auto cms64 = exprBuilder.constant(packedMaskSize64); + auto cms96 = exprBuilder.constant(packedMaskSize96); + auto c64 = exprBuilder.constant(64); + auto c96 = exprBuilder.constant(96); + + auto is64 = exprBuilder.operation(DimensionOperation::kEQUAL, *inputs[0].d[SDIM], *c64); + auto is96 = exprBuilder.operation(DimensionOperation::kEQUAL, *inputs[0].d[SDIM], *c96); + auto sel64 = exprBuilder.operation(DimensionOperation::kPROD, *is64, *cms64); + auto sel96 = exprBuilder.operation(DimensionOperation::kPROD, *is96, *cms96); + auto maskSize2 = exprBuilder.operation(DimensionOperation::kSUM, *sel64, *sel96); + maskSize = exprBuilder.operation(DimensionOperation::kSUM, *maskSize, *maskSize2); + + auto is0 = exprBuilder.operation(DimensionOperation::kEQUAL, *maskSize, *exprBuilder.constant(0)); + auto sel0 = exprBuilder.operation(DimensionOperation::kPROD, *is0, *cms0); + auto combinedMaskSize = exprBuilder.operation(DimensionOperation::kSUM, *maskSize, *sel0); + ret.d[1] = combinedMaskSize; + } + else + { + ret.d[1] = cms0; + } + + return ret; + } + catch (std::exception const& e) + { + caughtError(e); + } + return DimsExprs{}; +} + +bool EmbLayerNormPluginDynamicLegacy::supportsFormatCombination( + int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept +{ + // 3 inputs of size BxS + PLUGIN_ASSERT(nbInputs == 3); + PLUGIN_ASSERT(nbOutputs == 2); + + PluginTensorDesc const& desc = inOut[pos]; + if (desc.format != TensorFormat::kLINEAR) + { + return false; + } + if (pos == 0) + { + return desc.type == DataType::kINT32 && desc.dims.nbDims == 2; + } + + PluginTensorDesc const& prev = inOut[pos - 1]; + if (pos == 1 || pos == 2) + { + return desc.type == DataType::kINT32 && desc.dims.nbDims == 2 && desc.dims.d[BDIM] == prev.dims.d[BDIM] + && desc.dims.d[SDIM] == prev.dims.d[SDIM]; + } + + // embedded sequence + if (pos == 3) + { + return desc.type == mType && desc.dims.nbDims == 5 && desc.dims.d[BDIM] == prev.dims.d[BDIM] + && desc.dims.d[SDIM] == prev.dims.d[SDIM] && desc.dims.d[3] == 1 && desc.dims.d[4] == 1; + } + // mask + return desc.type == DataType::kINT32; +} + +void EmbLayerNormPluginDynamicLegacy::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs, + DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept +{ + BERT_DEBUG_MSG("EmbLayerNormPluginDynamicLegacy configurePlugin."); + + // Validate input arguments + PLUGIN_ASSERT(nbOutputs == 2); + PLUGIN_ASSERT(nbInputs == 3); + + PLUGIN_ASSERT(inputs[0].desc.dims.nbDims == 2); + int32_t const S = inputs[0].desc.dims.d[SDIM]; + mS = S; + int32_t const B = inputs[0].desc.dims.d[BDIM]; + TRT_UNUSED B; + PLUGIN_ASSERT(mS == static_cast(inputs[1].desc.dims.d[SDIM])); + PLUGIN_ASSERT(B == inputs[1].desc.dims.d[BDIM]); + PLUGIN_ASSERT(mS == static_cast(inputs[2].desc.dims.d[SDIM])); + PLUGIN_ASSERT(B == inputs[2].desc.dims.d[BDIM]); + + PLUGIN_ASSERT(outputs[0].desc.dims.nbDims == 5); + PLUGIN_ASSERT(static_cast(outputs[0].desc.dims.d[SDIM]) == mS); + PLUGIN_ASSERT(outputs[0].desc.dims.d[BDIM] == B); + PLUGIN_ASSERT(static_cast(outputs[0].desc.dims.d[2]) == mLd); + PLUGIN_ASSERT(outputs[0].desc.dims.d[3] == 1); + PLUGIN_ASSERT(outputs[0].desc.dims.d[4] == 1); + + if (mUseFullMask) + { + // user force full_mask + PLUGIN_ASSERT(outputs[1].desc.dims.nbDims == 2); + PLUGIN_ASSERT(outputs[1].desc.dims.d[0] == B); + PLUGIN_ASSERT((outputs[1].desc.dims.d[1] == -1) || (outputs[1].desc.dims.d[1] == packedMaskSize384) + || (outputs[1].desc.dims.d[1] == packedMaskSize128)); + } + else + { + // auto detect using mhatype + if (S != -1 && B != -1) + { + PLUGIN_ASSERT(outputs[1].desc.dims.nbDims == 2); + PLUGIN_ASSERT(outputs[1].desc.dims.d[0] == B); + int32_t packedSize = getMHAMaskPackedSize(mSM, mMhaType, S); + TRT_UNUSED packedSize; + PLUGIN_ASSERT(outputs[1].desc.dims.d[1] == -1 || outputs[1].desc.dims.d[1] == packedSize); + } + } + + PLUGIN_ASSERT(inputs[0].desc.type == DataType::kINT32); + PLUGIN_ASSERT(inputs[1].desc.type == DataType::kINT32); + PLUGIN_ASSERT(inputs[2].desc.type == DataType::kINT32); + PLUGIN_ASSERT(outputs[0].desc.type == mType); + PLUGIN_ASSERT(outputs[1].desc.type == DataType::kINT32); +} + +size_t EmbLayerNormPluginDynamicLegacy::getWorkspaceSize( + PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept +{ + return 0; +} + +int32_t EmbLayerNormPluginDynamicLegacy::enqueue(PluginTensorDesc const* inputDesc, + PluginTensorDesc const* /* outputDesc */, void const* const* inputs, void* const* outputs, void* /* workspace */, + cudaStream_t stream) noexcept +{ + try + { + PLUGIN_VALIDATE(inputDesc != nullptr && inputs != nullptr && outputs != nullptr); + + int32_t const batchSize = inputDesc->dims.d[BDIM]; + int32_t const S = inputDesc->dims.d[SDIM]; + int32_t status = STATUS_FAILURE; + + // Our plugin outputs only one tensor + auto const inputIds = static_cast(inputs[0]); + auto const segmentIds = static_cast(inputs[1]); + auto const inputMask = static_cast(inputs[2]); + + float const* beta = mBetaDev.get(); + float const* gamma = mGammaDev.get(); + if (mType == DataType::kFLOAT) + { + auto output = static_cast(outputs[0]); + auto const wordEmb = static_cast(mWordEmbDev.get()); + auto const tokEmb = static_cast(mTokEmbDev.get()); + auto const posEmb = static_cast(mPosEmbDev.get()); + status = embSkipLayerNorm(stream, static_cast(mLd), batchSize, S, inputIds, segmentIds, + beta, gamma, wordEmb, posEmb, tokEmb, mWordVocabSize, mTokVocabSize, output); + + if (status != cudaSuccess) + { + return status; + } + } + else if (mType == DataType::kHALF) + { + auto output = static_cast(outputs[0]); + auto const wordEmb = static_cast(mWordEmbDev.get()); + auto const tokEmb = static_cast(mTokEmbDev.get()); + auto const posEmb = static_cast(mPosEmbDev.get()); + status = embSkipLayerNorm(stream, static_cast(mLd), batchSize, S, inputIds, segmentIds, beta, + gamma, wordEmb, posEmb, tokEmb, mWordVocabSize, mTokVocabSize, output); + + if (status != cudaSuccess) + { + return status; + } + } + else + { + gLogError << "Unsupported type error, expected [kHALF,kFLOAT], but received " << static_cast(mType) + << std::endl; + + return STATUS_NOT_SUPPORTED; + } + + // check mha use fused kernel + if (mUseFullMask || unfusedMaskSize != getMHAMaskPackedSize(mSM, mMhaType, S)) + { + size_t warps_m = 0, warps_n = 0, warps_k = 1; + if (S == 64 || S == 96 || S == 128) + { + warps_m = 2; + warps_n = 2; + } + else if (S == 384) + { + warps_m = 1; + warps_n = 8; + } + uint32_t* inputMaskX = static_cast(outputs[1]); + + status = convertMask(S, batchSize, warps_m, warps_n, warps_k, inputMask, inputMaskX, stream); + } + else + { + int32_t* maskIdx = static_cast(outputs[1]); + status = computeMaskIdx(stream, S, batchSize, inputMask, maskIdx); + } + + return status; + } + catch (std::exception const& e) + { + caughtError(e); + } + return STATUS_FAILURE; +} + +// IPluginV2Ext Methods +DataType EmbLayerNormPluginDynamicLegacy::getOutputDataType( + int32_t index, DataType const* inputTypes, int32_t nbInputs) const noexcept +{ + + PLUGIN_ASSERT(index == 0 || index == 1); + if (index == 0) + { + PLUGIN_ASSERT(mType == DataType::kHALF || mType == DataType::kFLOAT); + return mType; + } + return DataType::kINT32; +} + +// IPluginV2 Methods +char const* EmbLayerNormPluginDynamicLegacy::getPluginType() const noexcept +{ + return gEmbLayerNormName; +} + +char const* EmbLayerNormPluginDynamicLegacy::getPluginVersion() const noexcept +{ + return gEmbLayerNormVersion; +} + +int32_t EmbLayerNormPluginDynamicLegacy::getNbOutputs() const noexcept +{ + return 2; +} + +int32_t EmbLayerNormPluginDynamicLegacy::initialize() noexcept +{ + return 0; +} + +void EmbLayerNormPluginDynamicLegacy::terminate() noexcept +{ + BERT_DEBUG_MSG("EmbLayerNormPluginDynamicLegacy terminate."); +} + +size_t EmbLayerNormPluginDynamicLegacy::getSerializationSize() const noexcept +{ + size_t const wordSize = getElementSize(mType); + return sizeof(mType) // type + + sizeof(mMhaType) // mha plugin datatype + + sizeof(mLd) * 5 // mLd, mS, m*VocabSize + + sizeof(mUseFullMask) // mask type + + sizeof(mSM) // smversion + + 2 * sizeof(float) * mLd // beta + gamma + + wordSize * mLd * mWordVocabSize // word emb + + wordSize * mLd * mPosVocabSize // pos emb + + wordSize * mLd * mTokVocabSize // tok emb + ; +} + +void EmbLayerNormPluginDynamicLegacy::serialize(void* buffer) const noexcept +{ + serialize_value(&buffer, mType); + serialize_value(&buffer, mMhaType); + serialize_value(&buffer, mLd); + serialize_value(&buffer, mS); + serialize_value(&buffer, mWordVocabSize); + serialize_value(&buffer, mPosVocabSize); + serialize_value(&buffer, mTokVocabSize); + serialize_value(&buffer, mUseFullMask); + serialize_value(&buffer, mSM); + + char* d = static_cast(buffer); + serFromDev(d, mBetaDev.get(), mLd); + serFromDev(d, mGammaDev.get(), mLd); + size_t const wordSize = getElementSize(mType); + serFromDev(d, static_cast(mWordEmbDev.get()), mLd * mWordVocabSize * wordSize); + serFromDev(d, static_cast(mPosEmbDev.get()), mLd * mPosVocabSize * wordSize); + serFromDev(d, static_cast(mTokEmbDev.get()), mLd * mTokVocabSize * wordSize); +} + +void EmbLayerNormPluginDynamicLegacy::destroy() noexcept +{ + BERT_DEBUG_MSG("EmbLayerNormPluginDynamicLegacy destroy."); + // This gets called when the network containing plugin is destroyed + mGammaDev.reset(nullptr); + mBetaDev.reset(nullptr); + mWordEmbDev.reset(nullptr); + mPosEmbDev.reset(nullptr); + mTokEmbDev.reset(nullptr); + delete this; +} + +void EmbLayerNormPluginDynamicLegacy::setPluginNamespace(char const* libNamespace) noexcept +{ + try + { + mNamespace = libNamespace; + } + catch (std::exception const& e) + { + caughtError(e); + } +} + +char const* EmbLayerNormPluginDynamicLegacy::getPluginNamespace() const noexcept +{ + return mNamespace.c_str(); +} + +/////////////////////// + +EmbLayerNormPluginDynamicLegacyCreator::EmbLayerNormPluginDynamicLegacyCreator() +{ + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(PluginField("bert_embeddings_layernorm_beta")); + mPluginAttributes.emplace_back(PluginField("bert_embeddings_layernorm_gamma")); + mPluginAttributes.emplace_back(PluginField("bert_embeddings_word_embeddings")); + mPluginAttributes.emplace_back(PluginField("bert_embeddings_token_type_embeddings")); + mPluginAttributes.emplace_back(PluginField("bert_embeddings_position_embeddings")); + mPluginAttributes.emplace_back(PluginField("output_fp16")); + mPluginAttributes.emplace_back(PluginField("full_mask")); + mPluginAttributes.emplace_back(PluginField("mha_type_id")); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +char const* EmbLayerNormPluginDynamicLegacyCreator::getPluginName() const noexcept +{ + return gEmbLayerNormName; +} + +char const* EmbLayerNormPluginDynamicLegacyCreator::getPluginVersion() const noexcept +{ + return gEmbLayerNormVersion; +} + +PluginFieldCollection const* EmbLayerNormPluginDynamicLegacyCreator::getFieldNames() noexcept +{ + return &mFC; +} + +IPluginV2* EmbLayerNormPluginDynamicLegacyCreator::createPlugin( + char const* name, PluginFieldCollection const* fc) noexcept +{ + try + { + BERT_DEBUG_MSG("EmbLayerNormPluginDynamicLegacy createPlugin."); + + bool output_fp16 = false; + bool useFullMask = false; + Weights beta{}; // required attribute - validateRequiredAttributesExist() will verify existence + Weights gamma{}; // required attribute - validateRequiredAttributesExist() will verify existence + Weights word_emb{}; // required attribute - validateRequiredAttributesExist() will verify existence + Weights pos_emb{}; // required attribute - validateRequiredAttributesExist() will verify existence + Weights tok_emb{}; // required attribute - validateRequiredAttributesExist() will verify existence + int32_t mhaTypeId = 0; + std::set const requiredAttributes{ + "bert_embeddings_layernorm_beta", + "bert_embeddings_layernorm_gamma", + "bert_embeddings_word_embeddings", + "bert_embeddings_token_type_embeddings", + "bert_embeddings_position_embeddings", + }; + plugin::validateRequiredAttributesExist(requiredAttributes, fc); + + for (int32_t i = 0; i < fc->nbFields; i++) + { + std::string field_name(fc->fields[i].name); + if (field_name.compare("bert_embeddings_layernorm_beta") == 0) + { + BERT_DEBUG_MSG("Building bert_embeddings_layernorm_beta..."); + beta.values = fc->fields[i].data; + beta.count = fc->fields[i].length; + beta.type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("bert_embeddings_layernorm_gamma") == 0) + { + BERT_DEBUG_MSG("Building bert_embeddings_layernorm_gamma..."); + gamma.values = fc->fields[i].data; + gamma.count = fc->fields[i].length; + gamma.type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("bert_embeddings_word_embeddings") == 0) + { + BERT_DEBUG_MSG("Building bert_embeddings_word_embeddings..."); + word_emb.values = fc->fields[i].data; + word_emb.count = fc->fields[i].length; + word_emb.type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("bert_embeddings_token_type_embeddings") == 0) + { + BERT_DEBUG_MSG("Building bert_embeddings_token_type_embeddings..."); + tok_emb.values = fc->fields[i].data; + tok_emb.count = fc->fields[i].length; + tok_emb.type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("bert_embeddings_position_embeddings") == 0) + { + BERT_DEBUG_MSG("Building bert_embeddings_position_embeddings..."); + pos_emb.values = fc->fields[i].data; + pos_emb.count = fc->fields[i].length; + pos_emb.type = fieldTypeToDataType(fc->fields[i].type); + } + if (field_name.compare("output_fp16") == 0) + { + BERT_DEBUG_MSG("Building output_fp16..."); + PLUGIN_VALIDATE(fc->fields[i].type == PluginFieldType::kINT32); + output_fp16 = static_cast(fc->fields[i].data)[0] != 0; + } + if (field_name.compare("full_mask") == 0) + { + BERT_DEBUG_MSG("Building full_mask..."); + PLUGIN_VALIDATE(fc->fields[i].type == PluginFieldType::kINT32); + useFullMask = static_cast(fc->fields[i].data)[0] != 0; + } + if (field_name.compare("mha_type_id") == 0) + { + mhaTypeId = *static_cast(fc->fields[i].data); + PLUGIN_VALIDATE(mhaTypeId >= 0 && mhaTypeId <= 3); + BERT_DEBUG_VALUE("Building mha typeId: ", mhaTypeId); + } + } + + BERT_DEBUG_MSG("Building the Plugin..."); + DataType mhaType = static_cast(mhaTypeId); + EmbLayerNormPluginDynamicLegacy* p + = new EmbLayerNormPluginDynamicLegacy(name, output_fp16 ? DataType::kHALF : DataType::kFLOAT, mhaType, beta, + gamma, word_emb, pos_emb, tok_emb, useFullMask); + return p; + } + catch (std::exception const& e) + { + caughtError(e); + } + return nullptr; +} + +IPluginV2* EmbLayerNormPluginDynamicLegacyCreator::deserializePlugin( + char const* name, void const* serialData, size_t serialLength) noexcept +{ + try + { + // This object will be deleted when the network is destroyed, which will + // call EmbLayerNormPluginDynamicLegacy::destroy() + return new EmbLayerNormPluginDynamicLegacy(name, serialData, serialLength); + } + catch (std::exception const& e) + { + caughtError(e); + } + return nullptr; +} + +void EmbLayerNormPluginDynamicLegacyCreator::setPluginNamespace(char const* libNamespace) noexcept +{ + try + { + mNamespace = libNamespace; + } + catch (std::exception const& e) + { + caughtError(e); + } +} + +char const* EmbLayerNormPluginDynamicLegacyCreator::getPluginNamespace() const noexcept +{ + return mNamespace.c_str(); +} + +#endif // CUDA_VERSION >= 10010 diff --git a/plugin/embLayerNormPlugin/embLayerNormPluginLegacy.h b/plugin/embLayerNormPlugin/embLayerNormPluginLegacy.h new file mode 100644 index 000000000..cb8f8165e --- /dev/null +++ b/plugin/embLayerNormPlugin/embLayerNormPluginLegacy.h @@ -0,0 +1,151 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#if CUDA_VERSION >= 10010 + +#ifndef TRT_EMB_LAYER_NORM_PLUGIN_LEGACY_H +#define TRT_EMB_LAYER_NORM_PLUGIN_LEGACY_H + +#include "NvInferPlugin.h" +#include "NvInferRuntime.h" + +#include "common/bertCommon.h" +#include +#include + +namespace nvinfer1 +{ +namespace plugin +{ +namespace bert +{ + +int32_t computeMaskIdx(cudaStream_t stream, int32_t const S, int32_t const B, int32_t const* mask, int32_t* maskIdx); + +template +int32_t embSkipLayerNorm(cudaStream_t stream, int32_t ld, int32_t B, int32_t S, int32_t const* inputIds, + int32_t const* token_ids, float const* beta, float const* gamma, T const* wordEmb, T const* posEmb, T const* tokEmb, + int32_t const wordSize, int32_t const tokSize, T* output); + +cudaError_t convertMask(uint32_t const S, uint32_t const B, uint32_t const warps_m, uint32_t const warps_n, + uint32_t const warps_k, int32_t const* inputMaskSB, uint32_t* inputMaskX, cudaStream_t stream); + +class EmbLayerNormPluginDynamicLegacy : public nvinfer1::IPluginV2DynamicExt +{ +public: + EmbLayerNormPluginDynamicLegacy(std::string const& name, nvinfer1::DataType const type, + nvinfer1::DataType const mhaType, nvinfer1::Weights const& beta, nvinfer1::Weights const& gamma, + nvinfer1::Weights const& word_emb, nvinfer1::Weights const& pos_emb, nvinfer1::Weights const& tok_emb, + bool const useFullMask); + + EmbLayerNormPluginDynamicLegacy(std::string const& name, void const* data, size_t length); + + // It doesn't make sense to make EmbLayerNormPluginDynamicLegacy without arguments, so we + // delete default constructor. + EmbLayerNormPluginDynamicLegacy() = delete; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; + nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination( + int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType( + int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override; + + // IPluginV2 Methods + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; + int32_t getNbOutputs() const noexcept override; + int32_t initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; + +private: + std::string const mLayerName; + std::string mNamespace; + + bert::cuda_unique_ptr mGammaDev; + bert::cuda_unique_ptr mBetaDev; + bert::cuda_unique_ptr mWordEmbDev; + bert::cuda_unique_ptr mTokEmbDev; + bert::cuda_unique_ptr mPosEmbDev; + size_t mLd; // leading dim = hidden size + size_t mS; // sequence length + size_t mWordVocabSize; + size_t mPosVocabSize; + size_t mTokVocabSize; + bert::WeightsWithOwnership mBeta; + bert::WeightsWithOwnership mGamma; + bert::WeightsWithOwnership mWordEmb; + bert::WeightsWithOwnership mTokEmb; + bert::WeightsWithOwnership mPosEmb; + nvinfer1::DataType mType; + bool mUseFullMask; + nvinfer1::DataType mMhaType; + int32_t mSM; + + using IPluginV2::getOutputDimensions; + using IPluginV2::getWorkspaceSize; + using IPluginV2::enqueue; + using IPluginV2Ext::configurePlugin; +}; + +class EmbLayerNormPluginDynamicLegacyCreator : public nvinfer1::IPluginCreator +{ +public: + EmbLayerNormPluginDynamicLegacyCreator(); + + char const* getPluginName() const noexcept override; + + char const* getPluginVersion() const noexcept override; + + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; + + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; + + nvinfer1::IPluginV2* deserializePlugin( + char const* name, void const* serialData, size_t serialLength) noexcept override; + + void setPluginNamespace(char const* pluginNamespace) noexcept override; + + char const* getPluginNamespace() const noexcept override; + +private: + static nvinfer1::PluginFieldCollection mFC; + static std::vector mPluginAttributes; + std::string mNamespace; +}; +} // namespace bert +} // namespace plugin +} // namespace nvinfer1 +#endif // TRT_EMB_LAYER_NORM_PLUGIN_LEGACY_H + +#endif // CUDA_VERSION >= 10010 diff --git a/plugin/fcPlugin/README.md b/plugin/fcPlugin/README.md index 913dfead2..500bc05f2 100644 --- a/plugin/fcPlugin/README.md +++ b/plugin/fcPlugin/README.md @@ -11,6 +11,8 @@ ## Description +> NOTE: This plugin is deprecated since TensorRT 10.6. Its functionality has been superseded by the [`IMatrixMultiplyLayer`](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_matrix_multiply_layer.html) (Can be added to the network definition using [`addMatrixMultiply()`](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_network_definition.html#acf109d93e91c86afbd263f5fea29ffe8)) + Performs a matrix multiplication similar to the FullyConnected Layer in TensorRT, but without bias. The main difference is that the weights are not transposed. Always dispatches to cuBLAS. At engine build time, the plugin runs a search over the parameters of the available algorithms to find the fastest one available. @@ -51,8 +53,8 @@ documentation. ## Changelog -November 2019 -This is the first release of this `README.md` file. +- October 2024: Add deprecation note. +- November 2019: This is the first release of this `README.md` file. ## Known issues diff --git a/python/docstrings/infer/pyCoreDoc.h b/python/docstrings/infer/pyCoreDoc.h index 541e66b82..ceb4f3631 100644 --- a/python/docstrings/infer/pyCoreDoc.h +++ b/python/docstrings/infer/pyCoreDoc.h @@ -713,7 +713,8 @@ constexpr char const* descr = R"trtdoc( :ivar streamable_weights_size: Returns the size of the streamable weights in the engine. This may not include all the weights. :ivar weight_streaming_budget_v2: Set and get the current weight streaming budget for inference. The budget may be set any non-negative value. A value of 0 streams the most weights. Values equal to streamable_weights_size (default) or larger will disable weight streaming. :ivar weight_streaming_scratch_memory_size: The amount of scratch memory required by a TensorRT ExecutionContext to perform inference. This value may change based on the current weight streaming budget. Please use the V2 memory APIs, engine.device_memory_size_v2 and ExecutionContext.set_device_memory() to provide memory which includes the current weight streaming scratch memory. Not specifying these APIs or using the V1 APIs will not include this memory, so TensorRT will resort to allocating itself. - )trtdoc"; + )trtdoc" + ; // Documentation bug with parameters on these three functions because they are overloaded. constexpr char const* serialize = R"trtdoc( @@ -955,13 +956,13 @@ constexpr char const* read = R"trtdoc( If an allocation request cannot be satisfied, ``0`` should be returned. - :arg destination: The host memory address to copy read memory to. :arg size: The number of bytes required. - :returns: The number of bytes read. + :returns: A buffer containing the bytes read. )trtdoc"; } // namespace StreamReaderDoc + namespace BuilderFlagDoc { constexpr char const* descr @@ -1010,6 +1011,9 @@ constexpr char const* REFIT_INDIVIDUAL constexpr char const* WEIGHT_STREAMING = R"trtdoc(Enable building with the ability to stream varying amounts of weights during Runtime. This decreases GPU memory of TRT at the expense of performance.)trtdoc"; constexpr char const* INT4 = R"trtdoc(Enable plugins with INT4 input/output)trtdoc"; +constexpr char const* STRICT_NANS + = R"trtdoc(Disable floating-point optimizations: 0*x => 0, x-x => 0, or x/x => 1. These identities are not true when x is a NaN or Inf, and thus might hide propagation or generation of NaNs.)trtdoc"; +constexpr char const* MONITOR_MEMORY = R"trtdoc(Enable memory monitor during build time.)trtdoc"; } // namespace BuilderFlagDoc namespace MemoryPoolTypeDoc @@ -1599,6 +1603,17 @@ constexpr char const* build_serialized_network = R"trtdoc( :returns: A pointer to a :class:`IHostMemory` object that contains a serialized network. )trtdoc"; +constexpr char const* build_engine_with_config = R"trtdoc( + Builds a network for the given :class:`INetworkDefinition` and :class:`IBuilderConfig` . + + This function allows building a network and creating an engine. + + :arg network: Network definition. + :arg config: Builder configuration. + + :returns: A pointer to a :class:`ICudaEngine` object that contains a built engine. +)trtdoc"; + constexpr char const* is_network_supported = R"trtdoc( Checks that a network is within the scope of the :class:`IBuilderConfig` settings. @@ -1664,7 +1679,8 @@ constexpr char const* deserialize_cuda_engine = R"trtdoc( constexpr char const* deserialize_cuda_engine_reader = R"trtdoc( Deserialize an :class:`ICudaEngine` from a stream reader. - :arg stream_reader: The :class:`PyStreamReader` that will read the serialized :class:`ICudaEngine`. This enables deserialization from a file directly. + :arg stream_reader: The :class:`PyStreamReader` that will read the serialized :class:`ICudaEngine`. This enables + deserialization from a file directly. :returns: The :class:`ICudaEngine`, or None if it could not be deserialized. )trtdoc"; diff --git a/python/docstrings/infer/pyGraphDoc.h b/python/docstrings/infer/pyGraphDoc.h index 74a651d38..32a230682 100644 --- a/python/docstrings/infer/pyGraphDoc.h +++ b/python/docstrings/infer/pyGraphDoc.h @@ -198,7 +198,8 @@ constexpr const char* descr = R"trtdoc( :ivar dynamic_range: :class:`Tuple[float, float]` [DEPRECATED] Deprecated in TensorRT 10.1. Superseded by explicit quantization. A tuple containing the [minimum, maximum] of the dynamic range, or :class:`None` if the range was not set. :ivar is_shape: :class:`bool` Whether the tensor is a shape tensor. :ivar allowed_formats: :class:`int32` The allowed set of TensorFormat candidates. This should be an integer consisting of one or more :class:`TensorFormat` s, combined via bitwise OR after bit shifting. For example, ``1 << int(TensorFormat.CHW4) | 1 << int(TensorFormat.CHW32)``. -)trtdoc"; +)trtdoc" + ; constexpr const char* set_dynamic_range = R"trtdoc( [DEPRECATED] Deprecated in TensorRT 10.1. Superseded by explicit quantization. @@ -1653,6 +1654,7 @@ constexpr const char* descr = R"trtdoc( )trtdoc"; } // namespace IDequantizeLayerDoc + namespace IIfConditionalBoundaryLayerDoc { constexpr const char* descr = R"trtdoc( @@ -2372,6 +2374,16 @@ constexpr const char* add_plugin_v2 = R"trtdoc( :returns: The new plugin layer, or :class:`None` if it could not be created. )trtdoc"; +constexpr const char* add_plugin = R"trtdoc( + Add a plugin layer to the network with a tuple of (inputs, shape_inputs, plugin). :func:`add_plugin_v3` can be thought of as an "unpacked tuple" version of this function. + + Primarily intended to be used when using the `tensorrt.plugin` module to implement the plugin. + + :arg tuple: A tuple of (inputs, shape_inputs, plugin). + + :returns: The new plugin layer, or :class:`None` if it could not be created. +)trtdoc"; + constexpr const char* add_plugin_v3 = R"trtdoc( Add a plugin layer to the network using an :class:`IPluginV3` interface. See :class:`IPluginV3` for more information. @@ -2441,6 +2453,7 @@ constexpr const char* add_dequantize = R"trtdoc( :returns: The new dequantization layer, or :class:`None` if it could not be created. )trtdoc"; + constexpr const char* add_if_conditional = R"trtdoc( Adds an if-conditional to the network, which provides a way to specify subgraphs that will be conditionally executed using lazy evaluation. See :class:`IIfConditional` for more information. diff --git a/python/include/impl/plugin.h b/python/include/impl/plugin.h new file mode 100644 index 000000000..0c22e297e --- /dev/null +++ b/python/include/impl/plugin.h @@ -0,0 +1,290 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRT_PYTHON_IMPL_PLUGIN_H +#define TRT_PYTHON_IMPL_PLUGIN_H + +#include "NvInfer.h" + +//! +//! \file plugin.h +//! +//! This file contains definitions for supporting the `tensorrt.plugin` Python module +//! +//! \warning None of the defintions here are part of the TensorRT C++ API and may not follow semantic versioning rules. +//! TensorRT clients must not utilize them directly. +//! + +namespace nvinfer1 +{ +namespace v_1_0 +{ + +class IPluginV3QuickCore : public IPluginCapability +{ +public: + InterfaceInfo getInterfaceInfo() const noexcept override + { + return InterfaceInfo{"PLUGIN_V3QUICK_CORE", 1, 0}; + } + + virtual AsciiChar const* getPluginName() const noexcept = 0; + + virtual AsciiChar const* getPluginVersion() const noexcept = 0; + + virtual AsciiChar const* getPluginNamespace() const noexcept = 0; +}; + +class IPluginV3QuickBuild : public IPluginCapability +{ +public: + InterfaceInfo getInterfaceInfo() const noexcept override + { + return InterfaceInfo{"PLUGIN_V3QUICK_BUILD", 1, 0}; + } + + //! + //! \brief Provide the data types of the plugin outputs if the input tensors have the data types provided. + //! + //! \param outputTypes Pre-allocated array to which the output data types should be written. + //! \param nbOutputs The number of output tensors. This matches the value returned from getNbOutputs(). + //! \param inputTypes The input data types. + //! \param inputRanks Ranks of the input tensors + //! \param nbInputs The number of input tensors. + //! + //! \return 0 for success, else non-zero + //! + virtual int32_t getOutputDataTypes(DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, + int32_t const* inputRanks, int32_t nbInputs) const noexcept = 0; + + //! + //! \brief Provide expressions for computing dimensions of the output tensors from dimensions of the input tensors. + //! + //! \param inputs Expressions for dimensions of the input tensors + //! \param nbInputs The number of input tensors + //! \param shapeInputs Expressions for values of the shape tensor inputs + //! \param nbShapeInputs The number of shape tensor inputs + //! \param outputs Pre-allocated array to which the output dimensions must be written + //! \param exprBuilder Object for generating new dimension expressions + //! + //! \return 0 for success, else non-zero + //! + virtual int32_t getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs, + int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept = 0; + + //! + //! \brief Configure the plugin. Behaves similarly to `IPluginV3OneBuild::configurePlugin()` + //! + //! \return 0 for success, else non-zero + //! + virtual int32_t configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, + DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept = 0; + + //! + //! \brief Get number of format combinations supported by the plugin for the I/O characteristics indicated by + //! `inOut`. + //! + virtual int32_t getNbSupportedFormatCombinations( + DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept = 0; + + //! + //! \brief Write all format combinations supported by the plugin for the I/O characteristics indicated by `inOut` to + //! `supportedCombinations`. It is guaranteed to have sufficient memory allocated for (nbInputs + nbOutputs) * + //! getNbSupportedFormatCombinations() `PluginTensorDesc`s. + //! + //! \return 0 for success, else non-zero + //! + virtual int32_t getSupportedFormatCombinations(DynamicPluginTensorDesc const* inOut, int32_t nbInputs, + int32_t nbOutputs, PluginTensorDesc* supportedCombinations, int32_t nbFormatCombinations) noexcept = 0; + + //! + //! \brief Get the number of outputs from the plugin. + //! + virtual int32_t getNbOutputs() const noexcept = 0; + + //! + //! \brief Communicates to TensorRT that the output at the specified output index is aliased to the input at the + //! returned index. Behaves similary to `v_2_0::IPluginV3OneBuild.getAliasedInput()`. + //! + virtual int32_t getAliasedInput(int32_t outputIndex) noexcept + { + return -1; + } + + //! + //! \brief Query for any custom tactics that the plugin intends to use specific to the I/O characteristics indicated + //! by the immediately preceding call to `configurePlugin()`. + //! + //! \return 0 for success, else non-zero + //! + virtual int32_t getValidTactics(int32_t* tactics, int32_t nbTactics) noexcept + { + return 0; + } + + //! + //! \brief Query for number of custom tactics related to the `getValidTactics()` call. + //! + virtual int32_t getNbTactics() noexcept + { + return 0; + } + + //! + //! \brief Called to query the suffix to use for the timing cache ID. May be called anytime after plugin creation. + //! + virtual char const* getTimingCacheID() noexcept + { + return nullptr; + } + + //! + //! \brief Query for a string representing the configuration of the plugin. May be called anytime after + //! plugin creation. + //! + virtual char const* getMetadataString() noexcept + { + return nullptr; + } +}; + +class IPluginV3QuickRuntime : public IPluginCapability +{ +public: + InterfaceInfo getInterfaceInfo() const noexcept override + { + return InterfaceInfo{"PLUGIN_V3QUICK_RUNTIME", 1, 0}; + } + + //! + //! \brief Set the tactic to be used in the subsequent call to enqueue(). Behaves similar to + //! `IPluginV3OneRuntime::setTactic()`. + //! + //! \return 0 for success, else non-zero + //! + virtual int32_t setTactic(int32_t tactic) noexcept + { + return 0; + } + + //! + //! \brief Execute the plugin. + //! + //! \param inputDesc how to interpret the memory for the input tensors. + //! \param outputDesc how to interpret the memory for the output tensors. + //! \param inputs The memory for the input tensors. + //! \param inputStrides Strides for input tensors. + //! \param outputStrides Strides for output tensors. + //! \param outputs The memory for the output tensors. + //! \param nbInputs Number of input tensors. + //! \param nbOutputs Number of output tensors. + //! \param stream The stream in which to execute the kernels. + //! + //! \return 0 for success, else non-zero + //! + virtual int32_t enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, Dims const* inputStrides, Dims const* outputStrides, + int32_t nbInputs, int32_t nbOutputs, cudaStream_t stream) noexcept = 0; + + //! + //! \brief Get the plugin fields which should be serialized. + //! + virtual PluginFieldCollection const* getFieldsToSerialize() noexcept = 0; +}; + +class IPluginCreatorV3Quick : public IPluginCreatorInterface +{ +public: + InterfaceInfo getInterfaceInfo() const noexcept override + { + return InterfaceInfo{"PLUGIN CREATOR_V3QUICK", 1, 0}; + } + + //! + //! \brief Return a plugin object. Return nullptr in case of error. + //! + //! \param name A NULL-terminated name string of length 1024 or less, including the NULL terminator. + //! \param namespace A NULL-terminated name string of length 1024 or less, including the NULL terminator. + //! \param fc A pointer to a collection of fields needed for constructing the plugin. + //! \param phase The TensorRT phase in which the plugin is being created + //! + virtual IPluginV3* createPlugin(AsciiChar const* name, AsciiChar const* nspace, PluginFieldCollection const* fc, + TensorRTPhase phase) noexcept = 0; + + //! + //! \brief Return a list of fields that need to be passed to createPlugin() when creating a plugin for use in the + //! TensorRT build phase. + //! + virtual PluginFieldCollection const* getFieldNames() noexcept = 0; + + virtual AsciiChar const* getPluginName() const noexcept = 0; + + virtual AsciiChar const* getPluginVersion() const noexcept = 0; + + virtual AsciiChar const* getPluginNamespace() const noexcept = 0; + + IPluginCreatorV3Quick() = default; + virtual ~IPluginCreatorV3Quick() = default; + +protected: + IPluginCreatorV3Quick(IPluginCreatorV3Quick const&) = default; + IPluginCreatorV3Quick(IPluginCreatorV3Quick&&) = default; + IPluginCreatorV3Quick& operator=(IPluginCreatorV3Quick const&) & = default; + IPluginCreatorV3Quick& operator=(IPluginCreatorV3Quick&&) & = default; +}; + +} // namespace v_1_0 + +//! +//! \class IPluginV3QuickCore +//! +//! \brief Provides core capability (`IPluginCapability::kCORE` for quickly-deployable TRT plugins) +//! +//! \warning This class is strictly for the purpose of supporting quickly-deployable TRT Python plugins and is not part +//! of the public TensorRT C++ API. Users must not inherit from this class. +//! +using IPluginV3QuickCore = v_1_0::IPluginV3QuickCore; + +//! +//! \class IPluginV3QuickBuild +//! +//! \brief Provides build capability (`IPluginCapability::kBUILD` for quickly-deployable TRT plugins) +//! +//! \warning This class is strictly for the purpose of supporting quickly-deployable TRT Python plugins and is not part +//! of the public TensorRT C++ API. Users must not inherit from this class. +//! +using IPluginV3QuickBuild = v_1_0::IPluginV3QuickBuild; + +//! +//! \class IPluginV3QuickRuntime +//! +//! \warning This class is strictly for the purpose of supporting quickly-deployable TRT Python plugins and is not part +//! of the public TensorRT C++ API. Users must not inherit from this class. +//! +using IPluginV3QuickRuntime = v_1_0::IPluginV3QuickRuntime; + +//! +//! \class IPluginCreatorV3Quick +//! +//! \warning This class is strictly for the purpose of supporting quickly-deployable TRT Python plugins and is not part +//! of the public TensorRT C++ API. Users must not inherit from this class. +//! +using IPluginCreatorV3Quick = v_1_0::IPluginCreatorV3Quick; + +} // namespace nvinfer1 + +#endif // TRT_PYTHON_IMPL_PLUGIN_H diff --git a/python/packaging/bindings_wheel/setup.cfg b/python/packaging/bindings_wheel/setup.cfg index 9e20a94e5..b6f5905e3 100644 --- a/python/packaging/bindings_wheel/setup.cfg +++ b/python/packaging/bindings_wheel/setup.cfg @@ -1,12 +1,19 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# [metadata] license_files = LICENSE.txt diff --git a/python/packaging/bindings_wheel/setup.py b/python/packaging/bindings_wheel/setup.py index 32b9a730d..19184de77 100644 --- a/python/packaging/bindings_wheel/setup.py +++ b/python/packaging/bindings_wheel/setup.py @@ -29,6 +29,8 @@ tensorrt_module += "-cu##CUDA_MAJOR##_bindings" package_name += "_bindings" +plugin_subpackage_name = f"{package_name}.plugin" + setup( name=tensorrt_module, version="##TENSORRT_PYTHON_VERSION##", @@ -41,12 +43,12 @@ "Intended Audience :: Developers", "Programming Language :: Python :: 3", ], - packages=[package_name], + packages=[package_name, plugin_subpackage_name], extras_require={"numpy": "numpy"}, package_data={package_name: ["*.so*", "*.pyd", "*.pdb", "*.dll*"]}, include_package_data=True, zip_safe=True, keywords="nvidia tensorrt deeplearning inference", - url="https://developer.nvidia.com/tensorrt", - download_url="https://github.com/nvidia/tensorrt/tags", + url="https://github.com/nvidia/tensorrt", + download_url="https://developer.nvidia.com/tensorrt", ) diff --git a/python/packaging/bindings_wheel/tensorrt/plugin/__init__.py b/python/packaging/bindings_wheel/tensorrt/plugin/__init__.py new file mode 100644 index 000000000..c6d551904 --- /dev/null +++ b/python/packaging/bindings_wheel/tensorrt/plugin/__init__.py @@ -0,0 +1,46 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tensorrt as trt + +logger = trt.Logger() +logger.log(trt.Logger.WARNING, "Functionality provided through tensorrt.plugin module is experimental in TensorRT 10.6.") + +# export.public_api() will expose things here. To make sure that happens, we just need to +# import all the submodules so that the decorator is actually executed (__discover_modules() below). +__all__ = [] + +def __discover_modules(): + import importlib + import pkgutil + + mods = [importlib.import_module(__package__)] + while mods: + mod = mods.pop(0) + + yield mod + + if hasattr(mod, "__path__"): + mods.extend( + [ + importlib.import_module(f"{mod.__name__}.{submod.name}") + for submod in pkgutil.iter_modules(mod.__path__) + ] + ) + + +_ = list(__discover_modules()) diff --git a/python/packaging/bindings_wheel/tensorrt/plugin/_autotune.py b/python/packaging/bindings_wheel/tensorrt/plugin/_autotune.py new file mode 100644 index 000000000..6b9a6aac6 --- /dev/null +++ b/python/packaging/bindings_wheel/tensorrt/plugin/_autotune.py @@ -0,0 +1,270 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import builtins +import tensorrt as trt +from typing import List, Iterable +import copy + +from ._utils import _str_to_data_type +from ._export import public_api + + +# "onesided" means either type or format combinations. After combinations for each are separately generated, we will combine them later. +# e.g. io_variants = ["FP32|FP16", "FP32|FP16", "FP32*FP16"] for a plugin with 3 I/Os. i.e. I/O indices 0 and 1 are dependently either FP32/FP16 and index 2 is independently FP32/FP16. +# There will be 2 * 2 = 4 combinations here: ["FP32", "FP32", "FP32"], ["FP16", "FP16", "FP32"], ["FP32", "FP32", "FP16"], ["FP16", "FP16", "FP16"] +def _gen_onesided_combinations(io_variants): + + # Algorithm: + # (1) Ignore independent variants and count the (max) number of dependent variants `mx_poly` + # (2) Compile initial list of #`mx_poly` combinations using the first option (option 0) for any independent variants + # (3) For each independent variant IO index, add combinations with that index replaced by option 1, 2, ... + + combinations = [] + mx_poly = 0 # This is the number of dependent variants + + for io_variant in io_variants: + io_variant_list = io_variant.split("|") + + if len(io_variant_list) > 1: + if "*" in io_variant: + raise ValueError( + f"Type/Format '{io_variant}' contains both '|' and '*'" + ) + if mx_poly > 1: + if mx_poly != len(io_variant_list): + raise ValueError( + f"Type/Format combinations {io_variants} contain illegal dependent lengths" + ) + + mx_poly = builtins.max(mx_poly, len(io_variant_list)) + + for _ in range(mx_poly): + combinations.append([None] * len(io_variants)) + + for j, io_variant in enumerate(io_variants): + io_variant_list = io_variant.split("|") + + if len(io_variant_list) == 1: + if "*" in io_variant: + io_variant_list = io_variant.split("*") + for i in range(len(combinations)): + combinations[i][j] = io_variant_list[0] + else: + for k in range(len(io_variant_list)): + combinations[k][j] = io_variant_list[k] + + for j, io_variant in enumerate(io_variants): + new_combs = [] + if "*" in io_variant: + io_variant_list = io_variant.split("*") + for k in range(1, len(io_variant_list)): + for c in combinations: + new_c = copy.deepcopy(c) + new_c[j] = io_variant_list[k] + new_combs.append(new_c) + combinations.extend(new_combs) + + return combinations + + +class _TypeFormatCombination: + def __init__(self, num=0): + self.types = [None] * num + self.layouts = [None] * num + self.tactics = [] + + def set_types(self, types): + self.types = types + + def set_layouts(self, layouts=None): + if isinstance(layouts, List): + self.layouts = layouts + else: + self.layouts = [layouts] * len(self.types) + + def __hash__(self): + return hash((tuple(self.types), tuple(self.layouts))) + + def __eq__(self, other): + return ( + isinstance(other, _TypeFormatCombination) + and self.types == other.types + and self.layouts == other.layouts + ) + + def __str__(self) -> str: + return "{" + str(self.types) + ", " + str(self.layouts) + "}" + + +@public_api() +class AutoTuneCombination: + def __init__( + self, io_types: str = None, layouts: str = None, tactics: Iterable[int] = None + ): + """ + Construct a set of supported type/format combinations of a plugin's I/O. + + Any custom *tactic* s per each such type/format combination can also be advertised. A tactic is simply another way to + calculate the output of a plugin for the same type/format combination of the I/O (e.g. if there are multiple kernels available). + + Args: + io_types (str, optional): A string representation of a type combination. + + Valid format is "type0,type1,...,type#io" where 'type' is of the form "TYPE0[sep]TYPE1[sep]...". + + TYPE is a valid string representation of a `trt.DataType`. These include "FP32" for trt.float32, "FP16" for trt.float16. The string representation of other data types is the same as their name in the trt.DataType enum. + + + [sep] is a valid separator, which is either '|' or '*'. Only one of these separators can appear in a given `io_types`. + + (1). '|' indicates a dependent combination: the dependence of the type of one I/O to another I/O. e.g. "FP32|FP16,FP32|FP16" indicates the IO can only be both FP32 or both FP16. + + (2). '*' indicates an independent combination. e.g. "FP32*FP16,FP32|FP16,FP32|FP16" indicates that the first input is independently either FP32 or FP16 regardless of the rest of the IO. + + layouts (str, optional): A string representation of a format combination. + + Valid format is "format0,format1,...,format#io" where 'format' is of the form "FORMAT0[sep]FORMAT1[sep]...". + + FORMAT is a valid string representation of a `trt.TensorFormat`. These are string versions for the enum values of `trt.TensorFormat`. e.g. "LINEAR" for `trt.TensorFormat.LINEAR`. + + [sep] is a valid separator, which is either '|' or '*'. The rules are the same as for `io_types`. + + tactics (Iterable[int], optional): Custom tactics for this type/format combination. Each custom tactic must be a positive integer. Defaults to default tactic (0). + + .. code-block:: python + :linenos: + :caption: For a plugin with 3 I/Os, I/O indices 0 and 1 are dependently either FP32/FP16 and index 2 is independently FP32/FP16. + + @trtp.autotune("my::plugin") + def autotune(inp0: trtp.TensorDesc, inp1: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc]) -> List[trtp.AutoTuneCombination]: + # The following would result in the following type combinations: + # [FP32, FP32, FP32], [FP16, FP16, FP32], [FP32, FP32, FP16], [FP16, FP16, FP16] + return [trtp.AutoTuneCombination("FP32|FP16, FP32|FP16, FP32|FP16", "LINEAR", [1, 2])] + + .. code-block:: python + :linenos: + :caption: For a plugin with 2 I/Os, the input/output supports either LINEAR or HWC format for FP32 and LINEAR format for FP16. + + @trtp.autotune("my::plugin") + def autotune(inp0: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc]) -> List[trtp.AutoTuneCombination]: + # Even though (FP16, HWC) is not a valid combination (see next example), TRT should intelligently reject those + # and pass the following combinations to the impl function: + # [{FP32, FP32}, {LINEAR, LINEAR}], [{FP32, FP32}, {HWC, LINEAR}], [{FP16, FP32}, {LINEAR, LINEAR}] + return [trtp.AutoTuneCombination("FP32*FP16, FP32", "LINEAR*HWC, LINEAR", [1, 2])] + + .. code-block:: python + :linenos: + :caption: For a plugin with 2 I/Os, the input/output supports either LINEAR or HWC format for FP32 and LINEAR format for FP16 (second method). + + @trtp.autotune("my::plugin") + def autotune(inp0: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc]) -> List[trtp.AutoTuneCombination]: + # We can use two AutoTuneCombination objects to avoid communicating illegal combinations + return [trtp.AutoTuneCombination("FP32*FP16, FP32", "LINEAR, LINEAR", [1, 2]), trtp.AutoTuneCombination("FP32, FP32", "HWC, LINEAR", [1, 2])] + """ + + if io_types is not None: + self.io_types = [s.strip() for s in io_types.split(",")] + if layouts is None: + layouts = "LINEAR" + self.layouts = [s.strip() for s in layouts.split(",")] + + if len(self.layouts) > 1: + assert len(self.io_types) == len(self.layouts) + + if len(self.io_types) > len(self.layouts): + assert len(self.layouts) == 1 + self.layouts = [self.layouts[0]] * len(self.io_types) + else: + self.io_types = [] + self.layouts = [] + + self.combinations = [] + self._tactics = tactics + + def pos(self, pos: Iterable[int], io_types: str, layouts: str = "LINEAR") -> None: + """ + Specify I/O types and formats for a specified set of I/O indices. + + Args: + pos (Iterable[int]): I/O indices. Input indices are [0, 1, ..., #inputs - 1] and output indices are [#inputs, #inputs + 1, ..., #inputs + #outputs - 1]. + io_types (str): Data types for these I/O indices. + layouts (str, optional): Tensor format(s) for these I/O indices. Defaults to "LINEAR". + Raises: + ValueError: If types or layouts for any of these I/O indices is already specified. + + .. code-block:: python + :linenos: + :caption: For a plugin with 3 I/Os, I/O indices 0 and 1 are dependently either FP32/FP16 and index 2 is independently FP32/FP16. + + @trtp.autotune("my::plugin") + def autotune(inp0: trtp.TensorDesc, inp1: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc]) -> List[trtp.AutoTuneCombination]: + c = trtp.AutoTuneCombination() + c.pos([0, 1], "FP32|FP16", "LINEAR") + c.pos(2, "FP32*FP16") # Omitting format is the same as declaring it to be LINEAR. + c.tactics([1, 2]) + return [c] + """ + if max(pos) >= len(self.io_types): + self.io_types.extend([None] * (max(pos) + 1 - len(self.io_types))) + self.layouts.extend([None] * (max(pos) + 1 - len(self.layouts))) + assert len(self.io_types) == len(self.layouts) + + for p in pos: + if self.io_types[p] is not None: + raise ValueError(f"Type(s) for position {p} already specified") + if self.layouts[p] is not None: + raise ValueError(f"Layout(s) for position {p} already specified") + self.io_types[p] = io_types + self.layouts[p] = layouts + + def tactics(self, tactics: Iterable[int]) -> None: + """ + Specify custom tactics for this type/format combination + + Args: + tactics (Iterable[int]): Custom tactics. These must be positive integers. + """ + self._tactics = tactics + + def _generate_combinations(self): + + self.combinations = [] + + type_combinations = _gen_onesided_combinations(self.io_types) + layout_combinations = _gen_onesided_combinations(self.layouts) + + for t in type_combinations: + for l in layout_combinations: + c = _TypeFormatCombination(len(self.io_types)) + c.types = [_str_to_data_type(tt) for tt in t] + c.layouts = [getattr(trt.TensorFormat, ff) for ff in l] + c.tactics = self._tactics + self.combinations.append(c) + + def _get_combinations(self): + self._generate_combinations() + return self.combinations + + def _check(self, pos, type, layout): + for i in range(len(self.combinations)): + if ( + self.combinations[i].types[pos] == _str_to_data_type(type) + and self.combinations[i].layouts[pos] == layout.name + ): + return True + return False diff --git a/python/packaging/bindings_wheel/tensorrt/plugin/_export.py b/python/packaging/bindings_wheel/tensorrt/plugin/_export.py new file mode 100644 index 000000000..d4ce23989 --- /dev/null +++ b/python/packaging/bindings_wheel/tensorrt/plugin/_export.py @@ -0,0 +1,36 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from types import ModuleType +import importlib + +def public_api(module: ModuleType = None, symbol: str = None): + def export_impl(obj): + nonlocal module, symbol + + module = module or importlib.import_module(__package__) + symbol = symbol or obj.__name__ + + if not hasattr(module, "__all__"): + module.__all__ = [] + + module.__all__.append(symbol) + setattr(module, symbol, obj) + + return obj + + return export_impl diff --git a/python/packaging/bindings_wheel/tensorrt/plugin/_lib.py b/python/packaging/bindings_wheel/tensorrt/plugin/_lib.py new file mode 100644 index 000000000..eabc437a1 --- /dev/null +++ b/python/packaging/bindings_wheel/tensorrt/plugin/_lib.py @@ -0,0 +1,523 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tensorrt as trt +import types +import typing +from typing import Callable, Tuple, List +import numpy as np + +from ._plugin_class import _TemplatePlugin +from ._validate import ( + _parse_register_inputs, + _parse_register_return, + _validate_autotune, + _validate_impl, + _validate_name_and_namespace, +) +from ._utils import ( + _built_in_to_plugin_field_type, + _join_with, + _numpy_to_plugin_field_type, + _is_numpy_array, + _infer_numpy_type, +) + +from ._export import public_api + +# Namespace to which plugins are dynamically bound +# A namespace can be thought of as a library of plugins from the same author/common objective +class _PluginNamespace(types.ModuleType): + def __init__(self, namespace): + super().__init__("tensorrt.plugin.op." + namespace) + self._namespace = namespace + + def define(self, name, plugin_def): + assert not hasattr(self, name) + setattr(self, name, plugin_def) + + def __getattr__(self, name): + raise AttributeError( + f"'{self.__class__.__name__}' object '{self._namespace}' has no attribute '{name}'" + ) + + def __repr__(self): + return f'_PluginNamespace(namespace="{self._namespace}")' + + +# `tensorrt.plugin.op` module to which plugin namespaces are dynamically bound +class _Op(types.ModuleType): + def __init__(self): + super().__init__("tensorrt.plugin.op") + + def define_or_get(self, namespace): + if hasattr(self, namespace): + return getattr(self, namespace) + + ns = _PluginNamespace(namespace) + setattr(self, namespace, ns) + + return ns + + def __getattr__(self, name): + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) + + +op = _Op() +public_api(symbol="op")(op) + +QDP_CREATORS = {} +QDP_REGISTRY = {} + +# Contains metadata about a registered plugin and `__call__()`` that allows for a plugin instance to be created +class PluginDef: + def __init__(self): + self.plugin_id = None # includes namespace (format is ns::name) + self.register_func = None + self.impl_func = None + self.autotune_func = None + self.autotune_attr_names = None + self.input_tensor_names = None + self.input_attrs = None # map name -> type + self.impl_attr_names = None + self.num_outputs = None + self.input_arg_schema = None + self.expects_tactic = None + + def __call__( + self, *args, **kwargs + ) -> Tuple[List[trt.ITensor], List[trt.ITensor], trt.IPluginV3]: + namespace, name = self.plugin_id.split("::") + + input_tensors = [] + schema_chunks = [] + + for t in args: + if not isinstance(t, trt.ITensor): + raise ValueError( + f"Expected trt.ITensor but got input of type {type(t)}" + ) + + schema_chunks.append("ITensor") + input_tensors.append(t) + + attrs = {} + for key, value in kwargs.items(): + if key not in self.input_attrs: + raise ValueError( + f"Unexpected attribute {key} provided. Expected one of {self.input_attrs.keys()}." + ) + attrs[key] = value + attr_annotation = self.input_attrs[key] + if isinstance(value, np.ndarray): + if typing.get_origin(attr_annotation) == np.ndarray: + np_dtype = typing.get_args(typing.get_args(attr_annotation)[1])[0] + if np.dtype(np_dtype) != np.dtype(value.dtype): + raise ValueError( + f"Unexpected dtype '{np.dtype(value.dtype)}' for attribute '{key}'. Expected '{np_dtype}'." + ) + else: + if attr_annotation is not type(value): + raise ValueError( + f"Unexpected type '{type(value)}' for attribute '{key}'. Expected '{attr_annotation}'." + ) + + schema_chunks.append(key) + + expected_schema = ( + f"({_join_with(['ITensor'] * len(self.input_tensor_names))}" + + _join_with(self.input_attrs.keys(), True) + + ")" + ) + schema = f"({', '.join(schema_chunks)})" + + if schema != expected_schema: + raise ValueError( + f"Unexpected schema {schema} received. Expected {expected_schema}." + ) + + if self.plugin_id in QDP_CREATORS: + plg_creator = trt.get_plugin_registry().get_creator(name, "1", namespace) + else: + attrs_types = {} + for key, value in kwargs.items(): + if isinstance(value, np.ndarray): + attrs_types[key] = (False, value.dtype) # (builtin?, type) + else: + attrs_types[key] = (True, type(value)) # (builtin?, type) + + plg_creator = _register_plugin_creator(name, namespace, attrs_types) + + fields = [] + for key, value in attrs.items(): + if isinstance(value, np.ndarray): + np_type = np.dtype(value.dtype) + if np_type == np.float16: + fields.append( + trt.PluginField( + key, value.tobytes(), trt.PluginFieldType.UNKNOWN + ) + ) + else: + fields.append( + trt.PluginField( + key, value, _numpy_to_plugin_field_type[np_type] + ) + ) + elif isinstance(value, str): + fields.append( + trt.PluginField(key, value.encode(), trt.PluginFieldType.CHAR) + ) + elif isinstance(value, bytes): + fields.append(trt.PluginField(key, value, trt.PluginFieldType.UNKNOWN)) + else: + fields.append( + trt.PluginField( + key, + np.array([value]), + _built_in_to_plugin_field_type[type(value)], + ) + ) + + plg = plg_creator.create_plugin( + name, + namespace, + trt.PluginFieldCollection(fields), + trt.TensorRTPhase.BUILD, + ) + plg.init( + self.register_func, + attrs, + self.impl_attr_names, + self.impl_func, + self.autotune_attr_names, + self.autotune_func, + self.expects_tactic, + ) + + return input_tensors, [], plg + + +class _TemplatePluginCreator(trt.IPluginCreatorV3Quick): + def __init__(self, name, namespace, attrs): + trt.IPluginCreatorV3Quick.__init__(self) + self.name = name + self.plugin_namespace = namespace + self.plugin_version = "1" + field_names = [] + for name, (builtin, type_) in attrs.items(): + if builtin: + if type_ is str: + field_names.append( + trt.PluginField(name, b"", trt.PluginFieldType.CHAR) + ) + elif type_ is bytes: + field_names.append( + trt.PluginField(name, b"", trt.PluginFieldType.UNKNOWN) + ) + else: + field_names.append( + trt.PluginField( + name, np.array([]), _built_in_to_plugin_field_type[type_] + ) + ) + else: + field_names.append( + trt.PluginField( + name, np.array([]), _numpy_to_plugin_field_type[np.dtype(type_)] + ) + ) + + self.field_names = trt.PluginFieldCollection(field_names) + + def create_plugin(self, name, namespace, fc, phase): + desc = QDP_REGISTRY[f"{namespace}::{name}"] + name = name + namespace = namespace + + attrs = {} + for f in fc: + if f.name not in desc.input_attrs: + raise AssertionError( + f"Unexpected attribute {f.name} provided to create_plugin. Expected one of {desc.input_attrs.keys()}." + ) + + attr_type_annot = desc.input_attrs[f.name] + if _is_numpy_array(attr_type_annot): + np_type = _infer_numpy_type(attr_type_annot) + if np_type == np.float16: + attrs[f.name] = np.frombuffer(f.data.tobytes(), dtype=np.float16) + else: + attrs[f.name] = f.data.astype(np_type) + else: + if issubclass(attr_type_annot, str): + attrs[f.name] = f.data.tobytes().decode("utf-8") + else: + attrs[f.name] = attr_type_annot(f.data) + + plg = _TemplatePlugin(name, namespace, desc.num_outputs) + plg.init( + desc.register_func, + attrs, + desc.impl_attr_names, + desc.impl_func, + desc.autotune_attr_names, + desc.autotune_func, + desc.expects_tactic, + ) + return plg + + +def _register_plugin_creator(name: str, namespace: str, attrs_types): + plg_registry = trt.get_plugin_registry() + plg_creator = _TemplatePluginCreator(name, namespace, attrs_types) + plg_registry.register_creator(plg_creator, namespace) + plg_creator = plg_registry.get_creator(name, "1", namespace) + QDP_CREATORS[f"{name}::{namespace}"] = plg_creator + return plg_creator + + +# Decorator for `tensorrt.plugin.register` +# By default, the plugin will be immediately registered in the TRT plugin registry +# During plugin development/when building engine, lazy registration may be used to delay plugin registration until the plugin is explicitly instantiated using `trt.plugin.op.ns.plugin_name(...)` +@public_api() +def register(plugin_id: str, lazy_register: bool = False) -> Callable: + """ + Wraps a function to register and describe a TensorRT plugin's IO characteristics. In addition, a complete plugin at least needs an `trt.plugin.impl` function to be registered. + + This API is only intended to be used as a decorator. The decorated function must have type hints for all inputs as well as return value. + + .. code-block:: text + + (inp0: TensorDesc, inp1: TensorDesc, ..., attr0: SupportedAttrType, attr1: SupportedAttrType, ...) -> Union[TensorDesc, Tuple[TensorDesc]] + + * Input tensors are declared first, each described by a tensor descriptor TensorDesc. + * Plugin attributes are declared next. "SupportedAttrType" must be one of: + * Supported built-in types: int, float, str, bool, bytes (Note: Lists/tuples of these types are not supported) + * 1-D Numpy arrays of the following types: int8, int16, int32, int64, float16, float32, float64, bool. These must be annotated with 'numpy.typing.NDArray[dtype]', where 'dtype' is the expected numpy dtype. + * If the plugin has only one output, the return annotation could be TensorDesc. Tuple[TensorDesc] could be used for any number of outputs. + + By default, the plugin will be immediately registered in the TRT plugin registry. Use the lazy_register argument to change this. + + Args: + plugin_id: An ID for the plugin in the form "{namespace}::{name}", + e.g. "my_project::add_plugin". The namespace is used to avoid collisions + so using your product/project name is recommended. + + lazy_register: During plugin development/when building engine, lazy registration may be used to delay plugin registration until the plugin is explicitly instantiated using `trt.plugin.op.ns.plugin_name(...)` + + .. code-block:: python + :linenos: + :caption: Registration of an elementwise plugin (output has same characteristics as the input) + + import tensorrt.plugin as trtp + + @trtp.register("my::add_plugin") + def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]: + return inp0.like() + + """ + + def decorator(register_func: Callable): + + plugin_ns, plugin_name = plugin_id.split("::") + _validate_name_and_namespace(plugin_ns, plugin_name) + + op_namespace = op.define_or_get(plugin_ns) + + if hasattr(op_namespace, plugin_name): + raise ValueError( + f"'{op.__class__.__name__}' already has a defintion for '{plugin_name}'" + ) + + ( + tensor_names, + input_attrs, + input_arg_schema, + attrs_types, + ) = _parse_register_inputs(register_func, lazy_register) + + plugin_def = PluginDef() + plugin_def.plugin_id = plugin_id + plugin_def.register_func = register_func + plugin_def.input_tensor_names = tensor_names + plugin_def.input_attrs = input_attrs + plugin_def.input_arg_schema = input_arg_schema + + num_outputs = _parse_register_return(register_func) + + plugin_def.num_outputs = num_outputs + QDP_REGISTRY[plugin_id] = plugin_def + + if not lazy_register: + _register_plugin_creator(plugin_name, plugin_ns, attrs_types) + + op_namespace.define(plugin_name, plugin_def) + + return register_func + + return decorator + + +# Decorator for `tensorrt.plugin.impl` +@public_api() +def impl(plugin_id: str) -> Callable: + """ + Wraps a function to define an implementation for a plugin already registered through `trt.plugin.register`. + + This API is only intended to be used as a decorator. The decorated function is not required to have type hints for input arguments or return value; + however, any type hints specified will be validated against the `trt.plugin.register` signature for consistency. + + The schema for the function is as follows: + + .. code-block:: text + + (inp0: Tensor, inp1: Tensor, ..., attr0: SupportedAttrType, attr1: SupportedAttrType, outputs: Tuple[Tensor], stream: int, tactic: Optional[int]) -> None + + * Input tensors are passed first, each described by a `Tensor`. + * Plugin attributes are declared next. + * Not all attributes included in `trt.plugin.register` must be specified here -- they could be a subset. + * Included attributes will be serialized to the TRT engine. Therefore, only attributes the plugin actually needs to perform inference (within the body of `trt.plugin.impl`) should be included. + * `tactic` is an optional argument. If the plugin is using custom tactics, it must be specified to receive the tactic value to use for the current execution of the plugin. + + Args: + plugin_id: The ID for the plugin in the form "{namespace}::{name}", which must match that used during `trt.plugin.register` + + .. code-block:: python + :linenos: + :caption: Implementation of an elementwise plugin with an OpenAI Triton kernel + + import tensorrt.plugin as trtp + import triton + import triton.language as tl + + @triton.jit + def add_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + tl.store(y_ptr + offsets, x + 1, mask=mask) + + @trtp.register("my::add_plugin") + def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]: + return inp0.like() + + @trtp.impl("my::add_plugin") + def add_plugin_impl(inp0: trtp.Tensor, block_size: int, outputs: Tuple[trtp.Tensor], stream: int) -> None: + + n = inp0.numel() + inp0_t = torch.as_tensor(inp0, device="cuda") + out_t = torch.as_tensor(outputs[0], device="cuda") + + add_kernel[(triton.cdiv(n, block_size),)](inp0_t, out_t, n, BLOCK_SIZE = block_size) + """ + + def decorator(impl_func: Callable): + if plugin_id not in QDP_REGISTRY: + raise ValueError( + f"Plugin {plugin_id} is not registered. Did you register it with tensorrt.plugin.register API?" + ) + + plugin_def = QDP_REGISTRY[plugin_id] + impl_attr_names, found_tactic = _validate_impl(impl_func, plugin_def) + + plugin_def.impl_func = impl_func + plugin_def.impl_attr_names = impl_attr_names + plugin_def.expects_tactic = found_tactic + return impl_func + + return decorator + + +# Decorator for `tensorrt.plugin.autotune` +@public_api() +def autotune(plugin_id: str) -> Callable: + """ + Wraps a function to define autotune logic for a plugin already registered through `trt.plugin.register`. + + Autotuning is the process by which TensorRT executes the plugin over IO type/format combinations, and any custom tactics advertised as being supported by the plugin. + The (type, format, tactic) combination with the lowest latency is used to execute the plugin once the engine is built. + + .. note:: An autotune function is optional. If not specified, TensorRT will assume the plugin only supports input types specified at network creation, output types specifeid through `trt.plugin.register`, and linear formats for all I/O. + + This API is only intended to be used as a decorator. The decorated function is not required to have type hints for input arguments or return value; however, any type hints specified will be validated against the `trt.plugin.register` signature for consistency. + + The schema for the function is as follows: + + .. code-block:: text + + (inp0: TensorDesc, inp1: TensorDesc, ..., attr0: SupportedAttrType, attr1: SupportedAttrType, outputs: Tuple[TensorDesc]) -> List[AutoTuneCombination] + + * Input tensors are passed first, each described by a :class:`TensorDesc`. + * Plugin attributes are declared next. Not all attributes included in `trt.plugin.register` must be specified here -- they could be a subset. + * The function should return a list of :class:`AutoTuneCombination`\s. + + Args: + plugin_id: The ID for the plugin in the form "{namespace}::{name}", which must match that used during `trt.plugin.register` + + .. code-block:: python + :linenos: + :caption: An elementwise add plugin which supports both FP32 and FP16 linear I/O and wants to be tuned over 2 custom tactics. + + import tensorrt.plugin as trtp + + @trtp.register("my::add_plugin") + def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]: + return inp0.like() + + @trtp.autotune("my::add_plugin") + def add_plugin_autotune(inp0: trtp.TensorDesc, block_size: int, outputs: Tuple[trtp.TensorDesc]) -> List[trtp.AutoTuneCombination]: + + return [trtp.AutoTuneCombination("FP32|FP16, FP32|FP16", "LINEAR", [1, 2])] + + .. code-block:: python + :linenos: + :caption: Same as above example but using index-by-index construction of an `AutoTuneCombination` + + import tensorrt.plugin as trtp + + @trtp.register("my::add_plugin") + def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]: + return inp0.like() + + @trtp.autotune("my::add_plugin") + def add_plugin_autotune(inp0: trtp.TensorDesc, block_size: int, outputs: Tuple[trtp.TensorDesc]) -> List[trtp.AutoTuneCombination]: + c = trtp.AutoTuneCombination() + c.pos(0, "FP32|FP16", "LINEAR") + c.pos(1, "FP32|FP16") # index 1 is the output. Omitting format is the same as declaring it to be LINEAR. + c.tactics([1, 2]) + return [c] + """ + + def decorator(autotune_func: Callable): + if plugin_id not in QDP_REGISTRY: + raise ValueError( + f"Plugin {plugin_id} is not registered. Did you register it with tensorrt.plugin.register API?" + ) + + plugin_def = QDP_REGISTRY[plugin_id] + autotune_attr_names = _validate_autotune(autotune_func, plugin_def) + + plugin_def.autotune_func = autotune_func + plugin_def.autotune_attr_names = autotune_attr_names + + return autotune_func + + return decorator diff --git a/python/packaging/bindings_wheel/tensorrt/plugin/_plugin_class.py b/python/packaging/bindings_wheel/tensorrt/plugin/_plugin_class.py new file mode 100644 index 000000000..c09104217 --- /dev/null +++ b/python/packaging/bindings_wheel/tensorrt/plugin/_plugin_class.py @@ -0,0 +1,322 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tensorrt as trt +from typing import Tuple + +import numpy as np +from ._utils import _numpy_to_plugin_field_type, _built_in_to_plugin_field_type +from ._tensor import TensorDesc, Tensor, Shape, ShapeExpr, ShapeExprs +from ._autotune import _TypeFormatCombination + + +class _TemplatePlugin( + trt.IPluginV3, + trt.IPluginV3QuickCore, + trt.IPluginV3QuickBuild, + trt.IPluginV3QuickRuntime, +): + def __init__(self, name, namespace, num_outputs): + trt.IPluginV3.__init__(self) + trt.IPluginV3QuickCore.__init__(self) + trt.IPluginV3QuickBuild.__init__(self) + trt.IPluginV3QuickRuntime.__init__(self) + + self.plugin_version = "1" + self.input_types = [] + self.aliased_map = {} # output index -> input index + + self.plugin_namespace = namespace + self.plugin_name = name + self.num_outputs = num_outputs + + self.autotune_combs = [] + self.supported_combs = {} + self.curr_comb = None + self.expects_tactic = False + + def init( + self, + register_function, + attrs, + impl_attr_names, + impl_function, + autotune_attr_names, + autotune_function, + expects_tactic, + ): + self.register_function = register_function + self.impl_function = impl_function + self.attrs = attrs + self.impl_attr_names = impl_attr_names + self.autotune_attr_names = autotune_attr_names + self.autotune_function = autotune_function + self.expects_tactic = expects_tactic + + def get_capability_interface(self, type): + return self + + def get_output_data_types(self, input_types, ranks): + self.input_types = input_types + + input_descs = [None] * len(input_types) + input_desc_map = {} + for i in range(len(input_types)): + input_descs[i] = TensorDesc() + input_descs[i].dtype = input_types[i] + input_descs[i].shape_expr = ShapeExprs(ranks[i], _is_dummy=True) + input_descs[i]._immutable = True + input_desc_map[id(input_descs[i])] = i + + output_descs = self.register_function(*input_descs, **self.attrs) + if not isinstance(output_descs, Tuple): + output_descs = tuple([output_descs]) + + self.output_types = [] + + for i in range(len(output_descs)): + self.output_types.append(output_descs[i].dtype) + + if output_descs[i].get_aliased() is not None: + self.aliased_map[i] = input_desc_map[id(output_descs[i].get_aliased())] + else: + self.aliased_map[i] = -1 + + return self.output_types + + def get_fields_to_serialize(self): + fields = [] + for key, value in self.attrs.items(): + if key in self.impl_attr_names: + if isinstance(value, np.ndarray): + if np.dtype(value.dtype) == np.float16: + fields.append( + trt.PluginField( + key, value.tobytes(), trt.PluginFieldType.UNKNOWN + ) + ) + else: + fields.append( + trt.PluginField( + key, + value, + _numpy_to_plugin_field_type[np.dtype(value.dtype)], + ) + ) + elif isinstance(value, str): + fields.append( + trt.PluginField(key, value.encode(), trt.PluginFieldType.CHAR) + ) + elif isinstance(value, bytes): + fields.append( + trt.PluginField(key, value, trt.PluginFieldType.UNKNOWN) + ) + else: + fields.append( + trt.PluginField( + key, + np.array([value]), + _built_in_to_plugin_field_type[type(value)], + ) + ) + + return trt.PluginFieldCollection(fields) + + def get_output_shapes(self, inputs, shape_inputs, exprBuilder): + assert len(shape_inputs) == 0 # Shape inputs are not yet supported for QDPs + ShapeExpr._exprBuilder = exprBuilder + self.input_descs = [] + for i in range(len(inputs)): + desc = TensorDesc() + inp = inputs[i] + + desc.dtype = self.input_types[i] + desc.shape_expr = ShapeExprs(len(inp)) + for j in range(len(inp)): + desc.shape_expr[j] = ShapeExpr(inp[j]) + desc._immutable = True + + self.input_descs.append(desc) + + self.output_descs = self.register_function(*self.input_descs, **self.attrs) + if not isinstance(self.output_descs, Tuple): + self.output_descs = tuple([self.output_descs]) + + for idx, desc in enumerate(self.output_descs): + if desc.is_size_tensor: + desc._set_index(idx) + + output_exprs = [] + for i in range(len(self.output_descs)): + exprs = trt.DimsExprs(len(self.output_descs[i].shape_expr)) + for j in range(len(exprs)): + exprs[j] = self.output_descs[i].shape_expr[j]._expr + + output_exprs.append(exprs) + + return output_exprs + + def configure_plugin(self, inputs, outputs): + self.curr_comb = _TypeFormatCombination() + self.curr_comb.types = [inp.desc.type for inp in inputs] + [ + out.desc.type for out in outputs + ] + self.curr_comb.layouts = [inp.desc.format for inp in inputs] + [ + out.desc.format for out in outputs + ] + + def get_supported_format_combinations(self, in_out, num_inputs): + if self.autotune_function is not None: + if len(self.autotune_attr_names) > 0: + val = [self.attrs[k] for k in self.autotune_attr_names] + else: + val = () + + for i, desc in enumerate(in_out): + if i < num_inputs: + self.input_descs[i]._immutable = False + self.input_descs[i].shape = Shape(desc) + self.input_descs[i].format = desc.desc.format + self.input_descs[i].scale = desc.desc.scale + self.input_descs[i]._immutable = True + else: + self.output_descs[i - num_inputs]._immutable = False + self.output_descs[i - num_inputs].shape = Shape(desc) + self.output_descs[i - num_inputs].format = desc.desc.format + self.output_descs[i - num_inputs].scale = desc.desc.scale + self.output_descs[i - num_inputs]._immutable = True + + self.autotune_combs = self.autotune_function( + *self.input_descs, *val, self.output_descs + ) + + if len(self.autotune_combs) == 0: + default_comb = [None] * len(in_out) + comb = _TypeFormatCombination(len(in_out)) + for j in range(len(in_out)): + default_comb[j] = trt.PluginTensorDesc() + default_comb[j].type = ( + self.input_types[j] + if j < num_inputs + else self.output_descs[j - num_inputs].dtype + ) + default_comb[j].format = trt.TensorFormat.LINEAR + comb.types[j] = default_comb[j].type + comb.layouts[j] = default_comb[j].format + + self.supported_combs[comb] = set() + + return default_comb + + all_combs = [] + for comb in self.autotune_combs: + all_combs.extend(comb._get_combinations()) + + ret_supported_combs = [] + self.supported_combs = {} + + for i, comb in enumerate(all_combs): + value = self.supported_combs.get(comb) + if value is not None: + value.update(set(comb.tactics) if comb.tactics is not None else set()) + else: + self.supported_combs[comb] = ( + set(comb.tactics) if comb.tactics is not None else set() + ) + for j in range(len(in_out)): + curr_comb = trt.PluginTensorDesc() + curr_comb.type = comb.types[j] + curr_comb.format = comb.layouts[j] + ret_supported_combs.append(curr_comb) + + return ret_supported_combs + + def enqueue( + self, + input_desc, + output_desc, + inputs, + outputs, + in_strides, + out_strides, + stream, + ): + input_tensors = [None] * (len(inputs)) + aliased_input_idxs = list(self.aliased_map.values()) + + for i in range(len(inputs)): + input_tensors[i] = Tensor() + input_tensors[i].dtype = input_desc[i].type + input_tensors[i].shape = Shape(input_desc[i]) + input_tensors[i].format = input_desc[i].format + input_tensors[i].scale = input_desc[i].scale + input_tensors[i].data_ptr = inputs[i] + input_tensors[i]._stream = stream + input_tensors[i]._read_only = i not in aliased_input_idxs + input_tensors[i].strides = in_strides[i] + + output_tensors = [None] * (len(outputs)) + for i in range(len(outputs)): + output_tensors[i] = Tensor() + output_tensors[i].dtype = output_desc[i].type + output_tensors[i].shape = Shape(output_desc[i]) + output_tensors[i].format = output_desc[i].format + output_tensors[i].scale = output_desc[i].scale + output_tensors[i].data_ptr = outputs[i] + output_tensors[i]._stream = stream + output_tensors[i]._read_only = False + output_tensors[i].strides = out_strides[i] + + for i, j in self.aliased_map.items(): + output_tensors[i]._aliased_to = input_tensors[j] + input_tensors[j]._aliased_to = output_tensors[i] + + for t in input_tensors: + t._immutable = True + + for t in output_tensors: + t._immutable = True + + if len(self.impl_attr_names) > 0: + val = [self.attrs[k] for k in self.impl_attr_names] + else: + val = () + + if self.expects_tactic: + self.impl_function( + *input_tensors, *val, output_tensors, stream, self._tactic + ) + else: + self.impl_function(*input_tensors, *val, output_tensors, stream=stream) + + def get_aliased_input(self, output_index: int): + return self.aliased_map[output_index] + + def get_valid_tactics(self): + tactics = self.supported_combs.get(self.curr_comb) + assert tactics is not None + return list(tactics) + + def set_tactic(self, tactic): + self._tactic = tactic + + def clone(self): + cloned_plugin = _TemplatePlugin( + self.plugin_name, self.plugin_namespace, self.num_outputs + ) + cloned_plugin.__dict__.update(self.__dict__) + return cloned_plugin diff --git a/python/packaging/bindings_wheel/tensorrt/plugin/_tensor.py b/python/packaging/bindings_wheel/tensorrt/plugin/_tensor.py new file mode 100644 index 000000000..5e9e711b2 --- /dev/null +++ b/python/packaging/bindings_wheel/tensorrt/plugin/_tensor.py @@ -0,0 +1,808 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tensorrt as trt +from typing import Tuple, Union +import numpy as np +from ._export import public_api + +# Symbolic expression for a given dimension of a tensor +@public_api() +class ShapeExpr: + """ + Symbolic expression for single dimension of a tensor + """ + _exprBuilder = None # trt.IExprBuilder instance. Populated when a shape-calculation context is entered. + + def __init__(self, value: Union[int, trt.IDimensionExpr, "ShapeExpr"] = None): + """ + Args: + value (Union[int, trt.IDimensionExpr, ShapeExpr], optional): Constant or another symbolic expression. Defaults to creating a fake shape expression. + """ + self._is_dummy = False + self._dim_expr = None + self._is_size_tensor = False + if value is None: + self._is_dummy = True + elif isinstance(value, int): + if self._exprBuilder is None: + self._dim_expr = None + self._is_dummy = True + else: + self._dim_expr = ShapeExpr._exprBuilder.constant(value) + elif isinstance(value, trt.IDimensionExpr): + self._dim_expr = value + elif isinstance(value, ShapeExpr): + self._dim_expr = value._dim_expr + self._is_dummy = value._is_dummy + self._is_size_tensor = value._is_size_tensor + + def _op(self, op: trt.DimensionOperation, other: Union[int, "ShapeExpr"]): + if self._is_size_tensor: + raise ValueError("It is not permitted to perform binary operations on size tensor expressions") # trt limitation + if self._is_dummy: + return ShapeExpr() + if isinstance(other, int): + other = ShapeExpr(other) + return ShapeExpr(ShapeExpr._exprBuilder.operation(op, self._expr, other._expr)) + + # Binary operations for +, -, *, //, ==. < + # Those for ceil_div, max and min are provided as top-level functions of tensorrt.plugin + def __add__(self, other: Union[int, "ShapeExpr"]): + return self._op(trt.DimensionOperation.SUM, other) + + def __sub__(self, other: Union[int, "ShapeExpr"]): + return self._op(trt.DimensionOperation.SUB, other) + + def __mul__(self, other: Union[int, "ShapeExpr"]): + return self._op(trt.DimensionOperation.PROD, other) + + def __floordiv__(self, other: Union[int, "ShapeExpr"]): + return self._op(trt.DimensionOperation.FLOOR_DIV, other) + + def __eq__(self, other: Union[int, "ShapeExpr"]): + return self._op(trt.DimensionOperation.EQUAL, other) + + def __lt__(self, other: Union[int, "ShapeExpr"]): + return self._op(trt.DimensionOperation.LESS, other) + + def __repr__(self): + if self._is_dummy: + return f"FakeShapeExpr[id={id(self)}]" + elif not self.is_constant: + return f"ShapeExpr[id={id(self)}]" + return f"ShapeExpr[{self._expr.get_constant_value()}]" + + # A ShapeExpr may be "fake" when it is accessed in a non-shape calculation context. Fake `ShapeExpr`s are externally indistinguishable unless `is_constant` or `constant_value` is required. + # Therefore, constant checks/access must occur conditionally after evaluating `is_fake`. + @property + def is_fake(self) -> bool: + """ + A ShapeExpr may be "fake" when it is accessed in a non-shape calculation context. + Fake `ShapeExpr`s are externally indistinguishable unless `is_constant` or `constant_value` is required. + """ + return self._is_dummy + + @property + def is_size_tensor(self) -> bool: + """ + `True` if this represents a size tensor, `False` otherwise. + """ + return self._is_size_tensor + + @property + def is_constant(self) -> bool: + """ + `True` if this shape expression is a build-time constant, `False` otherwise. + + Raises: + RuntimeError: For fake :class:`ShapeExpr`\s. Check :attr:`is_fake` to determine accessibility. + """ + if self._is_dummy: + raise RuntimeError( + "Not accessible for fake 'ShapeExpr's. Check is_fake to determine accessibility." + ) + return self._expr.is_constant() + + def constant_value(self) -> int: + """ + Return value of the constant shape expression. + + Raises: + RuntimeError: For non-constant shape expressions. Check :attr:`is_constant` to determine accessibility. + """ + if not self.is_constant: + raise RuntimeError( + "Not accessible for non-constant shape expressions. Check is_constant to determine accessibility." + ) + return self._expr.get_constant_value() + + # Evaluate the underlying trt.IDimensionExpr, if so done lazily + @property + def _expr(self): + return self._dim_expr + +@public_api() +class SizeTensorShapeExpr(ShapeExpr): + """ + Extends :class:`ShapeExpr` + + A shape expression that represent a size tensor + + """ + def __init__(self, size_tensor_desc: "SizeTensorDesc"): + """ + .. note:: It is recommended to use :attr:`SizeTensorDesc.expr` to get a :class:`SizeTensorShapeExpr` representing a size tensor + """ + super().__init__() + self._is_size_tensor = True + self._is_dummy = size_tensor_desc.opt.is_fake + self._size_tensor_desc = size_tensor_desc + + def _op(self, op: trt.DimensionOperation, other: Union[int, "ShapeExpr"]): + raise ValueError("It is not permitted to perform binary operations on size tensor expressions") # TRT limitation + + @property + def is_constant(self): + if self._is_dummy: + raise RuntimeError( + "Not accessible for fake 'ShapeExpr's. Check is_fake to determine accessibility." + ) + return False + + @property + def _expr(self): + if self._dim_expr is not None: + return self._dim_expr + + self._dim_expr = super()._exprBuilder.declare_size_tensor(self._size_tensor_desc.index, self._size_tensor_desc.opt._expr, self._size_tensor_desc.upper_bound._expr) + return self._dim_expr + + def __repr__(self): + return f"ShapeExpr[is_size_tensor = True, id={id(self)}]" + +# Iterable holding `ShapeExpr`s +@public_api() +class ShapeExprs: + def __init__(self, length: int, _is_dummy: bool = False): + """ + Iterable holding :class:`ShapeExpr`\s + + Args: + length (int): Number of dimensions of the tensor + """ + self._length = length + self._is_dummy = _is_dummy + if _is_dummy: + self._shapes = [ShapeExpr()] * length + else: + self._shapes = [None] * length + + @classmethod + def from_tuple(cls, shape_exprs: Tuple[Union[ShapeExpr, int]]) -> "ShapeExprs": + """ + Args: + shape_exprs (Tuple[Union[ShapeExpr, int]]): Tuple to construct :class:`ShapeExprs` from + """ + shape_exprs_ = tuple([e if isinstance(e, ShapeExpr) else ShapeExpr(e) for e in shape_exprs]) + inst = cls(len(shape_exprs_)) + inst._shapes = list(shape_exprs_) + return inst + + def numel(self) -> ShapeExpr: + """ + Returns a symbolic expression for the number of elements + """ + ret = ShapeExpr(1) + for s in self._shapes: + ret *= s + return ret + + def __iter__(self): + return iter(self._shapes) + + def __getitem__(self, index): + return self._shapes[index] + + def __len__(self): + return self._length + + def __setitem__(self, index, shape): + if index >= self._length: + raise IndexError("Index out of range") + self._shapes[index] = shape + + def __repr__(self): + return f"ShapeExprs[{', '.join([s.__repr__() for s in self._shapes])}]" + + +# Numerical representation of a tensor shape +@public_api() +class Shape: + """ + Numerical representation of a tensor shape + """ + def __init__( + self, tensor_desc: Union[int, trt.DynamicPluginTensorDesc, trt.PluginTensorDesc] + ): + self._desc = tensor_desc + self._is_dynamic = None # set lazily + if isinstance(tensor_desc, trt.DynamicPluginTensorDesc): + self._length = len(tensor_desc.desc.dims) + self._shapes = tensor_desc.desc.dims + elif isinstance(tensor_desc, trt.PluginTensorDesc): + self._length = len(tensor_desc.dims) + self._shapes = tensor_desc.dims + + def numel(self) -> int: + """ + Number of elements contained + + Raises: + ValueError: When :attr:`is_dynamic` is `True` + """ + if self.is_dynamic: + raise ValueError("Shape has at least one dynamic dimension.") + return int(np.prod(self._shapes)) + + def __iter__(self): + yield from self._shapes + + def __getitem__(self, index): + return self._shapes[index] + + def __len__(self): + return self._length + + def __str__(self): + return "Shape" + str(tuple(self)) + + @property + def is_dynamic(self) -> bool: + """ + `True` if this tensor has at least one dynamic dimension, `False` otherwise. + """ + if self._is_dynamic is not None: + return self._is_dynamic + + self._is_dynamic = False + for d in self._shapes: + if d == -1: + self._is_dynamic = True + + return self._is_dynamic + + @property + def opt(self) -> Tuple[int]: + """ + Optimum value of dimensions specified for auto-tuning. + """ + if not self.is_dynamic: + raise ValueError("opt property is only accessible if is_dynamic is true") + return tuple(self._desc.opt) + + @property + def min(self) -> Tuple[int]: + """ + Lower bounds on tensor's dimensions. + """ + if not self.is_dynamic: + raise ValueError("min property is only accessible if is_dynamic is true") + return tuple(self._desc.min) + + @property + def max(self) -> Tuple[int]: + """ + Upper bounds on tensor's dimensions. + """ + if not self.is_dynamic: + raise ValueError("max property is only accessible if is_dynamic is true") + return tuple(self._desc.max) + + def __setitem__(self, index, val): + if index >= self._length: + raise IndexError("Index out of range") + self._shapes.desc[index] = val + + +# Descriptor for a tensor +# A `TensorDesc` never contains nor refers to any tensor data. +@public_api() +class TensorDesc: + """ + Descriptor for a tensor + A `TensorDesc` never contains nor refers to any tensor data. + """ + def __init__(self, shape_expr: ShapeExprs = None, dtype: trt.DataType = None, format: trt.TensorFormat = None, scale: float = None): + """ + Args: + shape_expr (ShapeExprs): The data with which to initialize the tensor. + dtype (trt.DataType): The data type of the tensor. + format (trt.TensorFormat): Format (layout) of the tensor. + scale (float): Scale for INT8 data type. + + .. code-block:: python + :linenos: + :caption: Creates a TensorDesc with constant shape expressions + + tensor = trt.TensorDesc((10, 2, 32, 32), dtype=trt.float32) + + .. code-block:: python + :linenos: + :caption: Creates a TensorDesc from shape expression of another TensorDesc + + tensor = trt.from_shape_expr(other.shape_expr, dtype=trt.float32) + """ + + # `TensorDesc` may or may not have `Shape` information but always has symbolic shape expressions and dtype + self._shape_expr = shape_expr + self._dtype = dtype + + # `shape`, `format`, and `scale` are only accessible if `has_shape`. Presently, this would be inside autotune. + self._shape = None + self._format = format + self._scale = scale + + self._aliased_to = None + self._immutable = False + + def numel(self) -> int: + """ + Returns: + Returns an int with the number of elements of the tensor. + + .. warning:: + Should only be called when TensorDesc.has_shape is true. If a symbolic expression for the number of elements is required, query TensorDesc.shape_expr.numel(). + """ + if not self.has_shape: + raise ValueError( + "TensorDesc has no shape information available at this stage. Inspect TensorDesc.has_shape to determine availability." + ) + return int(np.prod(self.shape)) + + @property + def ndim(self) -> int: + """ + Number of dimensions + """ + return len(self._shape_expr) + + @property + def is_size_tensor(self): + return False + + # Return a `TensorDesc` that has identical properties to `self` but is mutable + def like(self) -> "TensorDesc": + """ + Returns: + Returns a TensorDesc which has identical properties to this tensor, and is mutable. + + .. code-block:: python + :linenos: + :caption: Communicate that output tensor has identical properties to the input tensor + + @tensorrt.plugin.register("my::plugin") + def _(inp: tensorrt.plugin.TensorDesc) -> tensorrt.plugin.TensorDesc: + return inp.like() + """ + cloned = TensorDesc() + cloned.__dict__.update(self.__dict__) + cloned._immutable = False + return cloned + + # Return a `TensorDesc` that has identical properties to `self` AND is aliased to `self` (would result in a `Tensor` during enqueue sharing the same data buffer) + def aliased(self) -> "TensorDesc": + """ + Returns: + Returns a TensorDesc which has identical properties and is aliased to this tensor (would result in a `Tensor` during enqueue sharing the same data buffer). + Returned TensorDesc is immutable. + + .. code-block:: python + :linenos: + :caption: Communicate that output tensor has identical properties to the input tensor + + @tensorrt.plugin.register("my::plugin") + def _(inp: tensorrt.plugin.TensorDesc) -> tensorrt.plugin.TensorDesc: + return inp.aliased() + """ + cloned = TensorDesc() + cloned.__dict__.update(self.__dict__) + cloned._immutable = False + cloned._aliased_to = self + cloned._immutable = True + return cloned + + def get_aliased(self) -> "TensorDesc": + """ + Returns: + Returns a TensorDesc for the tensor which this tensor is aliased to. Returns None is this tensor is not aliased to any other tensor. + """ + return self._aliased_to + + def _validate_has_shape(self) -> None: + if not self.has_shape: + raise ValueError( + "TensorDesc has no shape information available at this stage. Inspect TensorDesc.has_shape to determine availability." + ) + + def _validate_not_immutable(self): + if hasattr(self, "_immutable") and self._immutable: + raise ValueError("Cannot modify immutable TensorDesc") + + @property + def shape_expr(self) -> ShapeExprs: + """ + Symbolic expressions for the tensor shape. + """ + return self._shape_expr + + @property + def dtype(self) -> trt.DataType: + """ + Data type of the tensor. + """ + return self._dtype + + @property + def shape(self) -> Shape: + """ + The (concrete) shape of the tensor. + + .. warning:: + Only accessible when TensorDesc.has_shape is true. + """ + self._validate_has_shape() + return self._shape + + @property + def format(self) -> trt.TensorFormat: + """ + The format of the tensor. + + .. warning:: + Only accessible when TensorDesc.has_shape is true. + """ + self._validate_has_shape() + return self._format + + @property + def scale(self) -> float: + """ + Scale for INT8 data type. + + .. warning:: + Only accessible when TensorDesc.has_shape is true. + """ + self._validate_has_shape() + return self._scale + + + @shape_expr.setter + def shape_expr(self, value): + self._shape_expr = value + + @dtype.setter + def dtype(self, value): + self._dtype = value + + @shape.setter + def shape(self, value): + self._validate_not_immutable() + self._shape = value + + @format.setter + def format(self, value): + self._validate_not_immutable() + self._format = value + + @scale.setter + def scale(self, value): + self._validate_not_immutable() + self._scale = value + + @property + def is_aliased(self) -> bool: + """ + True if this tensor is aliased to another tensor, False otherwise. + """ + return self._aliased_to is not None + + @property + def has_shape(self) -> bool: + """ + True if this tensor has concrete shape information, False otherwise. + """ + return self._shape is not None + + @property + def is_dynamic(self) -> bool: + """ + `True` if this tensor has at least one dynamic dimension, `False` otherwise. + """ + if not self.has_shape: + raise ValueError( + "TensorDesc has no shape information available at this stage. Inspect TensorDesc.has_shape to determine availability." + ) + return self.shape.is_dynamic + + @property + def has_shape_expr(self) -> bool: + """ + True if this tensor has symbolic shape expressions, False otherwise. + """ + return self.shape_expr is not None + + def __setattr__(self, name, value): + if hasattr(self, "_immutable") and self._immutable and name != "_immutable": + raise ValueError("Cannot modify immutable TensorDesc properties") + super().__setattr__(name, value) + +@public_api() +class SizeTensorDesc(TensorDesc): + """ + Extends :class:`TensorDesc` + + Descriptor for a size tensor: a scalar of either INT32 or INT64 data type used to express the extent of a data-dependent dimension. + """ + def __init__(self, opt: ShapeExpr, upper_bound: ShapeExpr): + """ + Args: + opt (ShapeExpr): Symbolic expression for the extent of this size tensor to use in the autotune process of the engine build + upper_bound (ShapeExpr): Symbolic expression for the upper-bound of this size tensor + + .. note:: It is recommended to construct a size tensor using :func:`size_tensor` instead of using this constructor directly + """ + super().__init__(ShapeExprs(0), trt.int32) + self._opt = opt + self._upper_bound = upper_bound + self._index = None + self._expr = SizeTensorShapeExpr(self) + + @property + def is_size_tensor(self): + return True + + @property + def opt(self) -> ShapeExpr: + """ + Symbolic expression for the extent of this size tensor to use in the autotune process of the engine build + """ + return self._opt + + @property + def upper_bound(self) -> ShapeExpr: + """ + Symbolic expression for the upper-bound of this size tensor + """ + return self._upper_bound + + @property + def index(self) -> int: + """ + Output index at which this size tensor resides + """ + return self._index + + def _set_index(self, idx: int): + self._index = idx + + def expr(self) -> SizeTensorShapeExpr: + """ + Symbolic expression for this size tensor + """ + return self._expr + + +# A tensor representation that carries data +@public_api() +class Tensor: + """ + Representation of a tensor that carries data + + :class:`Tensor` objects are strictly *descriptors* of a tensor with an underlying data buffer. `tensorrt.plugin` does not provide any APIs that perform standard data-altering operations on :class:`Tensor`\s. + + Supports `__cuda_array_interface__` for interoperability with other frameworks. + + """ + def __init__(self): + self._data_ptr = None + self._shape = None + self._format = None + self._dtype = None + self._scale = None + self._strides = None + + self._aliased_to = None + self._stream = None + self._read_only = None + self._immutable = False + + @property + def ndim(self) -> int: + """ + Number of dimensions + """ + return len(self._shape) + + @property + def data_ptr(self) -> int: + """ + Pointer to the data buffer of this tensor + """ + return self._data_ptr + + @property + def dtype(self) -> trt.DataType: + """ + Data type of the tensor. + """ + return self._dtype + + @property + def shape(self) -> Shape: + """ + The (concrete) shape of the tensor. + """ + return self._shape + + @property + def format(self) -> trt.TensorFormat: + """ + The format of the tensor. + """ + return self._format + + @property + def scale(self) -> float: + """ + Scale for INT8 data type. + """ + return self._scale + + @property + def strides(self) -> Tuple[int]: + """ + Strides of this tensor. + """ + return self._strides + + @data_ptr.setter + def data_ptr(self, value): + self._data_ptr = value + + @dtype.setter + def dtype(self, value): + self._dtype = value + + @shape.setter + def shape(self, value): + self._shape = value + + @format.setter + def format(self, value): + self._format = value + + @scale.setter + def scale(self, value): + self._scale = value + + @strides.setter + def strides(self, value): + self._strides = value + + def numel(self) -> int: + """ + Returns the number of elements of the tensor + + Raises: + ValueError: If the tensor has a data-dependent dimension. Examine :attr:`is_data_dependent` to determine whether the tensor is data-dependent. + + Returns: + int: Number of elements of the tensor + """ + if self.is_data_dependent: + raise ValueError( + "Tensor has a data-dependent dimension. Examine Tensor.shape to determine wildcards (representing data-dependent dimensions)." + ) + return int(np.prod(self._shape)) + + @property + def __cuda_array_interface__(self): + if self._dtype in [trt.DataType.BF16, trt.DataType.FP8, trt.DataType.INT4]: + raise ValueError( + f"Handling {self._dtype} via '__cuda_array_interface__' is not supported" + ) + + desc = { + "shape": tuple(self._shape), + "typestr": np.dtype(trt.nptype(self._dtype)).str, + } + desc["stream"] = self._stream + desc["version"] = 3 + desc["data"] = ( + self._data_ptr, + False, + ) # torch does not support read_only flag. Always set to False -- it is user's responsibility to respect implied read-write restriction(s). + desc["strides"] = tuple( + [s * np.dtype(trt.nptype(self._dtype)).itemsize for s in self._strides] + ) + + return desc + + def __setattr__(self, name, value): + if hasattr(self, "_immutable") and self._immutable and name != "_immutable": + raise ValueError("Cannot modify immutable Tensor properties") + super().__setattr__(name, value) + + def get_aliased(self) -> "Tensor": + """ + Returns: + Returns :class:`Tensor` of the tensor which this tensor is aliased to. Returns None is this tensor is not aliased to any other tensor. + """ + return self._aliased_to + + @property + def is_aliased(self): + """ + True if this tensor is aliased to another tensor, False otherwise. + """ + return self._aliased_to is None + + @property + def is_data_dependent(self): + """ + True if this tensor contains at least one data-dependent dimension, False otherwise. + """ + return self._shape.is_dynamic + + # Return a `Tensor` which has the same `data_ptr` as `self` but has the provided shape. + def aliased(self, shape: Union[Shape, Tuple[int], trt.PluginTensorDesc] = None) -> "Tensor": + """ + Return a :class:`Tensor` which has the same :attr:`data_ptr` as this but has the provided `shape`. + + Args: + shape (Union[Shape, Tuple[int], trt.PluginTensorDesc], optional): Required shape of the new tensor (must have the same volume). Defaults to same shape. + + Raises: + ValueError: If `shape` is not a supported type or if it does not have the same volume + """ + cloned = Tensor() + cloned.__dict__.update(self.__dict__) + cloned._immutable = False + if isinstance(shape, trt.PluginTensorDesc): + cloned._shape = Shape(shape) + elif isinstance(shape, Shape): + cloned._shape = shape + elif isinstance(shape, tuple): + desc = trt.PluginTensorDesc() + desc.dims = shape + desc.type = self._dtype + desc.format = self._format + desc.scale = self._scale + cloned._shape = Shape(desc) + elif shape is None: + pass + else: + raise ValueError("Unsupported type for 'shape'") + + # If either the `shape` or self._shape has a wildcard, we allow aliasing + if not self.is_data_dependent and cloned.is_data_dependent: + if cloned._shape.numel() > self.numel(): + raise ValueError("Volume of this tensor is less than the provided 'shape'.") + + cloned._aliased_to = self + return cloned diff --git a/python/packaging/bindings_wheel/tensorrt/plugin/_top_level.py b/python/packaging/bindings_wheel/tensorrt/plugin/_top_level.py new file mode 100644 index 000000000..03705e0fe --- /dev/null +++ b/python/packaging/bindings_wheel/tensorrt/plugin/_top_level.py @@ -0,0 +1,132 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Union, Tuple +import tensorrt as trt +from ._tensor import ShapeExpr, TensorDesc, ShapeExprs, SizeTensorDesc +from ._export import public_api + +# Miscellaneous top-level functions accessible through `tensorrt.plugin` + +# Performs `trt.DimensionOperation.CEIL_DIV` +@public_api() +def cdiv(first: Union[int, ShapeExpr], second: Union[int, ShapeExpr]) -> ShapeExpr: + """ + Computes symbolic ceiling division of `first` by `second` + + Args: + first (Union[int, ShapeExpr]): Dividend + second (Union[int, ShapeExpr]): Divisor + + Raises: + ValueError: If both arguments are `int`\s or if `second` evaluates to 0 + + Returns: + ShapeExpr: Symbolic expression for the ceiling division of `first` by `second` + """ + if isinstance(first, int): + if isinstance(second, int): + raise ValueError("Both arguments cannot be 'int's") + first = ShapeExpr(first) + + return first._op(trt.DimensionOperation.CEIL_DIV, second) + + +# Performs `trt.DimensionOperation.MAX` +@public_api() +def max(first: Union[int, ShapeExpr], second: Union[int, ShapeExpr]) -> ShapeExpr: + """ + Computes the maximum of `first` and `second` + + Args: + first (Union[int, ShapeExpr]): First operand + second (Union[int, ShapeExpr]): Second operand + + Raises: + ValueError: If both arguments are `int`\s + + Returns: + ShapeExpr: Symbolic expression for the maximum of `first` and `second` + """ + if isinstance(first, int): + if isinstance(second, int): + raise ValueError("Both arguments cannot be 'int's") + first = ShapeExpr(first) + + return first._op(trt.DimensionOperation.MAX, second) + + +# Performs `trt.DimensionOperation.MIN` +@public_api() +def min(first: Union[int, ShapeExpr], second: Union[int, ShapeExpr]) -> ShapeExpr: + """ + Computes the minimum of `first` and `second` + + Args: + first (Union[int, ShapeExpr]): First operand + second (Union[int, ShapeExpr]): Second operand + + Raises: + ValueError: If both arguments are `int`\s + + Returns: + ShapeExpr: Symbolic expression for the minimum of `first` and `second` + """ + if isinstance(first, int): + if isinstance(second, int): + raise ValueError("Both arguments cannot be 'int's") + first = ShapeExpr(first) + + return first._op(trt.DimensionOperation.MIN, second) + + +# Declare a size tensor descriptor with the specified autotune shape expression `opt` and `upper-bound` shape expression +@public_api() +def size_tensor(opt: ShapeExpr, upper_bound: ShapeExpr) -> SizeTensorDesc: + """ + Constructs a size tensor with the specified autotune shape expression `opt` and `upper_bound` + + Args: + opt (ShapeExpr): Symbolic expression for the extent of this size tensor to use in the autotune process of the engine build + upper_bound (ShapeExpr): Symbolic expression for the upper-bound of this size tensor + + Returns: + SizeTensorDesc: A tensor descriptor for a size tensor with the specified autotune extent and upper-bound + """ + return SizeTensorDesc(opt, upper_bound) + +# Create a TensorDesc using shape expressions and a dtype +@public_api() +def from_shape_expr(shape_expr: Union[Tuple[Union[ShapeExpr, int]], ShapeExprs], dtype: trt.DataType) -> TensorDesc: + """ + Constructs a tensor descriptor with the specified shape expression and data type + + Args: + shape_expr (Union[Tuple[Union[ShapeExpr, int]], ShapeExprs]): Expressions or constants denoting the shape of the tensor + dtype (trt.DataType): Data type of the tensor + + Returns: + TensorDesc: Tensor descriptor with the specified shape expression and data type + """ + if isinstance(shape_expr, tuple): + shape_expr_ = ShapeExprs.from_tuple(shape_expr) + else: + shape_expr_ = shape_expr + + return TensorDesc(shape_expr_, dtype) + + diff --git a/python/packaging/bindings_wheel/tensorrt/plugin/_utils.py b/python/packaging/bindings_wheel/tensorrt/plugin/_utils.py new file mode 100644 index 000000000..fd917f163 --- /dev/null +++ b/python/packaging/bindings_wheel/tensorrt/plugin/_utils.py @@ -0,0 +1,77 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tensorrt as trt +import numpy as np +import typing + +_numpy_to_plugin_field_type = { + np.dtype('int32'): trt.PluginFieldType.INT32, + np.dtype('int16'): trt.PluginFieldType.INT16, + np.dtype('int8'): trt.PluginFieldType.INT8, + np.dtype('bool'): trt.PluginFieldType.INT8, + np.dtype('int64'): trt.PluginFieldType.INT64, + np.dtype('float32'): trt.PluginFieldType.FLOAT32, + np.dtype('float64'): trt.PluginFieldType.FLOAT64, + np.dtype('float16'): trt.PluginFieldType.FLOAT16 +} + +_built_in_to_plugin_field_type = { + int: trt.PluginFieldType.INT64, + float: trt.PluginFieldType.FLOAT64, + bool: trt.PluginFieldType.INT8, + # str is handled separately, so not needed here +} + +def _str_to_data_type(dtype: str) -> trt.DataType: + if dtype == "FP32": + return trt.DataType.FLOAT + if dtype == "FP16": + return trt.DataType.HALF + try: + return getattr(trt.DataType, dtype) + except KeyError: + raise ValueError(f"Unknown data type string '{dtype}'") from None + + +def _join_with(lst, middle = False, delim = ", "): + if len(lst) == 0: + return "" + + ret = "" + if middle: + ret += ", " + + ret += delim.join(lst) + + return ret + +def _is_npt_ndarray(annotation): + return (typing.get_origin(annotation) == np.ndarray) or (hasattr(annotation, "__origin__") and annotation.__origin__ == np.ndarray) + +def _is_numpy_array(annotation): + return (annotation == np.ndarray) or _is_npt_ndarray(annotation) + +def _infer_numpy_type(annotation): + assert _is_npt_ndarray(annotation) + annot_args = typing.get_args(annotation) or annotation.__args__ + if len(annot_args) >= 2: + np_type = typing.get_args(annot_args[1]) or annot_args[1].__args__ + if len(np_type) >= 1: + return np_type[0] + + raise AttributeError("Improper annotation for numpy array. Annotate numpy array attributes using 'numpy.typing.NDArray[dtype]', where 'dtype' is the expected numpy dtype of the array.") diff --git a/python/packaging/bindings_wheel/tensorrt/plugin/_validate.py b/python/packaging/bindings_wheel/tensorrt/plugin/_validate.py new file mode 100644 index 000000000..48ee4fd97 --- /dev/null +++ b/python/packaging/bindings_wheel/tensorrt/plugin/_validate.py @@ -0,0 +1,357 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import inspect +import numpy as np +import typing + +from ._utils import _is_numpy_array, _join_with, _infer_numpy_type, _is_npt_ndarray +from ._tensor import TensorDesc, Tensor +from ._autotune import AutoTuneCombination + +SERIALIZABLE_BUILTIN_TYPES = (int, float, bytes, bool, str) +SERIALIZABLE_NP_DTYPES = ( + np.int8, + np.int16, + np.int32, + np.int64, + np.float16, + np.float32, + np.float64, + bool, + np.bool_, +) + +# Reserve some namespaces for future use/avoid confusion +RESERVED_NAMESPACES = { + "", + "trt", + "tensorrt", + "std", +} + +DISALLOWED_ATTR_NAMES = { + "outputs", + "stream", + "tactic", +} + +def _validate_name_and_namespace(ns: str, name: str): + if "." in ns: + raise ValueError( + f"Provided namespace {ns} cannot have any '.' in trt.plugin.register(\"{ns}::{name}\", ...)" + ) + + if "." in name: + raise ValueError( + f"Provided name {name} cannot have any '.' in trt.plugin.register(\"{ns}::{name}\", ...)" + ) + + if ns in RESERVED_NAMESPACES: + raise ValueError( + f"Provided namespace {ns} is a reserved namespace" + ) + + +# Parse `tensorrt.plugin.register` schema +def _parse_register_inputs(register_func, lazy_register): + tensor_names = [] + input_attrs = ( + dict() + ) # order is important here but for Python >= 3.7, dict respects key order + + schema_chunks = [] + + # TensorDescs and attribute args cannot be interspersed, so remember when we saw the first attribute arg + saw_first_attr = False + + # Map of (attr_name: str) -> (is_builtin_type?: bool, type annotation: str) + attrs_types = {} + + sig = inspect.signature(register_func) + + for idx, (name, param) in enumerate(sig.parameters.items()): + + if param.kind not in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ): + raise ValueError( + f"Argument {name} is not a positional-or-keyword or keyword-only arg" + ) + + # Type annotations are manadatory for `tensorrt.plugin.register` args + if param.annotation == inspect.Parameter.empty: + raise ValueError( + f"Argument {name} does not have a type annotation. Please mark as TensorDesc or one of the serializable attribute types." + ) + + # Presently, we do not support default values for attributes + if param.default is not inspect.Parameter.empty: + raise ValueError( + f"Argument {name} has a default value. Default values are not supported yet." + ) + + + if issubclass(param.annotation, TensorDesc): + if saw_first_attr: + raise ValueError( + f"TensorDescs args and attribute args cannot be interspersed. Received function with signature {sig}." + ) + + tensor_names.append(name) + schema_chunks.append(f"TensorDesc {name}") + # At this point, we don't validate attribute types since we only care about the types of serializable attributes + # However, we memorize name and type so that we may validate that the autotune function maintains consistency + else: + if idx == 0: + raise ValueError( + f"TensorDescs args should come first, followed by attributes. Received function with signature {sig}." + ) + + if name in DISALLOWED_ATTR_NAMES: + raise ValueError( + f"'{name}' is not allowed as a plugin attribute name." + ) + + if param.annotation not in SERIALIZABLE_BUILTIN_TYPES: + if _is_numpy_array(param.annotation): + if not lazy_register: + if param.annotation == np.ndarray: + raise ValueError( + "If using non-lazy registration, annotate numpy array attributes using 'numpy.typing.NDArray[dtype]', where 'dtype' is the expected numpy dtype of the array." + ) + + if _is_npt_ndarray(param.annotation): + np_dtype = _infer_numpy_type(param.annotation) + if np_dtype not in SERIALIZABLE_NP_DTYPES: + raise ValueError( + f"Attribute '{name}' is not a supported numpy array type. Supported numpy arrays type are {SERIALIZABLE_NP_DTYPES}." + ) + attrs_types[name] = (False, np_dtype) + + else: + raise ValueError( + f"Attribute '{name}' of type {param.annotation} is not a supported serializable type. Supported types are {SERIALIZABLE_BUILTIN_TYPES} or numpy arrays of type {SERIALIZABLE_NP_DTYPES}." + ) + else: + attrs_types[name] = (True, param.annotation) + + saw_first_attr = True + + schema_chunks.append(f"{param.annotation} {name}") + input_attrs[name] = param.annotation + + return ( + tensor_names, + input_attrs, + f"({_join_with(schema_chunks)})", + attrs_types, + ) + + +def _parse_register_return(register_func): + sig = inspect.signature(register_func) + + ret_annotation = sig.return_annotation + + if ret_annotation == inspect.Parameter.empty: + raise ValueError( + f"No return annotation found for register function. Received signature {sig}." + ) + + if typing.get_origin(ret_annotation) is not tuple: + if not inspect.isclass(ret_annotation) or not issubclass( + ret_annotation, TensorDesc + ): + raise ValueError( + f"Return argument is of type {ret_annotation}. Return types can only be TensorDesc or Tuple[TensorDesc]." + ) + + num_outputs = 1 + else: + args = typing.get_args(ret_annotation) + + for arg in args: + if not issubclass(arg, TensorDesc): + raise ValueError( + f"Return argument is of type {ret_annotation}. Return types can only be TensorDesc or Tuple[TensorDesc]." + ) + + num_outputs = len(args) + + return num_outputs + + +def _validate_impl(impl_func, plugin_def): + impl_attr_names = [] + found_tactic = False + + sig = inspect.signature(impl_func) + registered_attr_names = plugin_def.input_attrs.keys() + + # input arg annotations are optional, but we will validate if provided + for name, param in sig.parameters.items(): + # tactic arg is optional in impl function. If specified, remember so that we can pass it during enqueue. + if name == "tactic": + found_tactic = True + if param.annotation != inspect.Parameter.empty: + if name == "outputs": + if typing.get_origin(param.annotation) is not tuple: + raise ValueError( + f"'outputs' should be of type Tuple[Tensor]. Received {param.annotation}." + ) + args = typing.get_args(param.annotation) + for arg in args: + if not issubclass(arg, Tensor): + raise ValueError( + f"Argument for receiving output Tensor, '{name}' contains a {param.annotation}. '{name}' should be a Tuple[Tensor]." + ) + elif name == "stream": + if not issubclass(param.annotation, int): + raise ValueError("'stream' input argument should be an int") + elif name == "tactic": + if not issubclass(param.annotation, int): + raise ValueError("'tactic' input argument should be an int") + elif issubclass(param.annotation, Tensor): + if name not in plugin_def.input_tensor_names: + raise ValueError( + f"Unexpected tensor '{name}' specified in autotune function. Expected one of {plugin_def.input_tensor_names}." + ) + else: + if name not in plugin_def.input_attrs: + raise ValueError( + f"Unexpected attribute '{name}' specified in impl function. Expected one of {list(registered_attr_names)}." + ) + + if param.annotation != plugin_def.input_attrs[name]: + raise ValueError( + f"Attribute '{name}' has a type annotation different from the one specified at registration. Expected '{plugin_def.input_attrs[name]}'." + ) + + impl_attr_names.append(name) + else: + if name in plugin_def.input_attrs: + impl_attr_names.append(name) + + # Expected attribute schema should be constructed in the order they appeared in the register function + expected_attr_schema_chunks = [ + n for n in registered_attr_names if n in impl_attr_names + ] + + expected_schema = ( + "(" + + _join_with(plugin_def.input_tensor_names) + + _join_with(expected_attr_schema_chunks, True) + + ", outputs, stream" + ) + if found_tactic: + expected_schema += ", tactic)" + else: + expected_schema += ")" + + if f"({', '.join(sig.parameters.keys())})" != expected_schema: + raise ValueError( + f"Signature of the impl function '{sig}' does not match the expected input arg schema: {expected_schema}" + ) + + # Return annotation is optional, but we will validate if one is specified + if sig.return_annotation != inspect.Parameter.empty and sig.return_annotation is not None: + raise ValueError("Return annotation should be None.") + + return impl_attr_names, found_tactic + + +def _validate_autotune(autotune_func, plugin_def): + + sig = inspect.signature(autotune_func) + registered_attr_names = plugin_def.input_attrs.keys() + + autotune_attr_names = [] + + # input arg annotations are optional, but we will validate if provided + for name, param in sig.parameters.items(): + if param.annotation != inspect.Parameter.empty: + if name == "outputs": + if typing.get_origin(param.annotation) is not tuple: + raise ValueError( + f"'outputs' should be of type Tuple[TensorDesc]. Received {param.annotation}." + ) + args = typing.get_args(param.annotation) + for arg in args: + if not issubclass(arg, TensorDesc): + raise ValueError( + f"Argument for receiving output TensorDescs, '{name}' contains a {param.annotation}. '{name}' should be a Tuple[TensorDesc]." + ) + elif issubclass(param.annotation, TensorDesc): + if name not in plugin_def.input_tensor_names: + raise ValueError( + f"Unexpected tensor '{name}' specified in autotune function. Expected one of {plugin_def.input_tensor_names}." + ) + else: + if name not in plugin_def.input_attrs: + raise ValueError( + f"Unexpected attribute '{name}' specified in autotune function. Expected one of {list(registered_attr_names)}." + ) + if param.annotation != plugin_def.input_attrs[name]: + raise ValueError( + f"Attribute '{name}' has a type annotation different from the one specified at registration. Expected '{plugin_def.input_attrs[name]}'." + ) + + autotune_attr_names.append(name) + else: + if name in plugin_def.input_attrs: + autotune_attr_names.append(name) + + # Expected attribute schema should be constructed in the order they appeared in the register function + expected_attr_schema_chunks = [ + n for n in registered_attr_names if n in autotune_attr_names + ] + + expected_schema = ( + "(" + + _join_with(plugin_def.input_tensor_names) + + _join_with(expected_attr_schema_chunks, True) + + ", outputs)" + ) + + if f"({', '.join(sig.parameters.keys())})" != expected_schema: + raise ValueError( + f"Specified autotune function signature {sig} is not consistent with the expected input arg schema {expected_schema}." + ) + + ret_annotation = sig.return_annotation + + # Return annotation is optional, but we will validate if one is specified + if ret_annotation != inspect.Parameter.empty: + if typing.get_origin(ret_annotation) is not list: + if not inspect.isclass(ret_annotation) or not issubclass( + ret_annotation, AutoTuneCombination + ): + raise ValueError( + f"Return argument is of type {ret_annotation}. Return types can only be AutoTuneCombination or List[AutoTuneCombination]." + ) + else: + args = typing.get_args(ret_annotation) + + for arg in args: + if not issubclass(arg, AutoTuneCombination): + raise ValueError( + f"Return argument is of type {ret_annotation}. Return types can only be AutoTuneCombination or List[AutoTuneCombination]." + ) + + return autotune_attr_names diff --git a/python/packaging/frontend_sdist/setup.cfg b/python/packaging/frontend_sdist/setup.cfg index dea8290ce..5b78c91c0 100644 --- a/python/packaging/frontend_sdist/setup.cfg +++ b/python/packaging/frontend_sdist/setup.cfg @@ -1,12 +1,19 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# [metadata] license_files = LICENSE.txt diff --git a/python/packaging/frontend_sdist/setup.py b/python/packaging/frontend_sdist/setup.py index bb4d9172a..512d3ce16 100644 --- a/python/packaging/frontend_sdist/setup.py +++ b/python/packaging/frontend_sdist/setup.py @@ -64,7 +64,9 @@ def find_pip(): if sys.implementation.name != "cpython": raise RuntimeError("TensorRT currently only builds wheels for CPython") if platform.machine() not in ("x86_64", "AMD64", "aarch64"): - raise RuntimeError("TensorRT currently only builds wheels for x86_64 and ARM SBSA processors") + raise RuntimeError( + "TensorRT currently only builds wheels for x86_64 and ARM SBSA processors" + ) if "tegra" in platform.release(): raise RuntimeError("TensorRT does not currently build wheels for Tegra systems") @@ -167,6 +169,6 @@ def parent_command_line(): include_package_data=True, zip_safe=True, keywords="nvidia tensorrt deeplearning inference", - url="https://developer.nvidia.com/tensorrt", - download_url="https://github.com/nvidia/tensorrt/tags", + url="https://github.com/nvidia/tensorrt", + download_url="https://developer.nvidia.com/tensorrt", ) diff --git a/python/packaging/libs_wheel/setup.cfg b/python/packaging/libs_wheel/setup.cfg index dea8290ce..5b78c91c0 100644 --- a/python/packaging/libs_wheel/setup.cfg +++ b/python/packaging/libs_wheel/setup.cfg @@ -1,12 +1,19 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# [metadata] license_files = LICENSE.txt diff --git a/python/packaging/libs_wheel/setup.py b/python/packaging/libs_wheel/setup.py index b9f7af767..2d20f09fa 100644 --- a/python/packaging/libs_wheel/setup.py +++ b/python/packaging/libs_wheel/setup.py @@ -45,6 +45,6 @@ def get_requirements(): include_package_data=True, zip_safe=True, keywords="nvidia tensorrt deeplearning inference", - url="https://developer.nvidia.com/tensorrt", - download_url="https://github.com/nvidia/tensorrt/tags", + url="https://github.com/nvidia/tensorrt", + download_url="https://developer.nvidia.com/tensorrt", ) diff --git a/python/packaging/metapackage/setup.py b/python/packaging/metapackage/setup.py index 43e563e6d..11a972d6d 100644 --- a/python/packaging/metapackage/setup.py +++ b/python/packaging/metapackage/setup.py @@ -40,6 +40,6 @@ include_package_data=True, zip_safe=True, keywords="nvidia tensorrt deeplearning inference", - url="https://developer.nvidia.com/tensorrt", - download_url="https://github.com/nvidia/tensorrt/tags", + url="https://github.com/nvidia/tensorrt", + download_url="https://developer.nvidia.com/tensorrt", ) diff --git a/python/src/infer/pyCore.cpp b/python/src/infer/pyCore.cpp index c284d1d14..ae913bacb 100644 --- a/python/src/infer/pyCore.cpp +++ b/python/src/infer/pyCore.cpp @@ -165,6 +165,8 @@ static const auto runtime_deserialize_cuda_engine = [](IRuntime& self, py::buffe return self.deserializeCudaEngine(info.ptr, info.size * info.itemsize); }; + + // For ICudaEngine // TODO: Add slicing support? static const auto engine_getitem = [](ICudaEngine& self, int32_t pyIndex) { @@ -628,6 +630,7 @@ class PyStreamReader : public IStreamReader } }; + class PyDebugListener : public IDebugListener { public: @@ -1096,7 +1099,8 @@ void bindCore(py::module& m) IExecutionContextDoc::set_tensor_debug_state) .def("get_debug_state", &IExecutionContext::getDebugState, "name"_a, IExecutionContextDoc::get_debug_state) .def("set_all_tensors_debug_state", &IExecutionContext::setAllTensorsDebugState, "flag"_a, - IExecutionContextDoc::set_all_tensors_debug_state); + IExecutionContextDoc::set_all_tensors_debug_state) + ; py::enum_(m, "ExecutionContextAllocationStrategy", py::arithmetic{}, ExecutionContextAllocationStrategyDoc::descr, py::module_local()) @@ -1292,6 +1296,7 @@ void bindCore(py::module& m) "weight_streaming_scratch_memory_size", &ICudaEngine::getWeightStreamingScratchMemorySize) // End weight streaming APIs .def("is_debug_tensor", &ICudaEngine::isDebugTensor, "name"_a, ICudaEngineDoc::is_debug_tensor) + .def("__del__", &utils::doNothingDel); py::enum_(m, "AllocatorFlag", py::arithmetic{}, AllocatorFlagDoc::descr, py::module_local()) @@ -1333,6 +1338,7 @@ void bindCore(py::module& m) .def(py::init<>()) .def("read", &IStreamReader::read, "destination"_a, "size"_a, StreamReaderDoc::read); + py::enum_(m, "BuilderFlag", py::arithmetic{}, BuilderFlagDoc::descr, py::module_local()) .value("FP16", BuilderFlag::kFP16, BuilderFlagDoc::FP16) .value("BF16", BuilderFlag::kBF16, BuilderFlagDoc::BF16) @@ -1364,6 +1370,8 @@ void bindCore(py::module& m) .value("WEIGHT_STREAMING", BuilderFlag::kWEIGHT_STREAMING, BuilderFlagDoc::WEIGHT_STREAMING) .value("INT4", BuilderFlag::kINT4, BuilderFlagDoc::INT4) .value("REFIT_INDIVIDUAL", BuilderFlag::kREFIT_INDIVIDUAL, BuilderFlagDoc::REFIT_INDIVIDUAL) + .value("STRICT_NANS", BuilderFlag::kSTRICT_NANS, BuilderFlagDoc::STRICT_NANS) + .value("MONITOR_MEMORY", BuilderFlag::kMONITOR_MEMORY, BuilderFlagDoc::MONITOR_MEMORY) ; py::enum_(m, "MemoryPoolType", MemoryPoolTypeDoc::descr, py::module_local()) @@ -1526,6 +1534,8 @@ void bindCore(py::module& m) py::keep_alive<0, 1>{}) .def("build_serialized_network", &IBuilder::buildSerializedNetwork, "network"_a, "config"_a, BuilderDoc::build_serialized_network, py::call_guard{}) + .def("build_engine_with_config", &IBuilder::buildEngineWithConfig, "network"_a, "config"_a, + BuilderDoc::build_engine_with_config, py::call_guard{}) .def("is_network_supported", &IBuilder::isNetworkSupported, "network"_a, "config"_a, BuilderDoc::is_network_supported, py::call_guard{}) .def_property_readonly("logger", &IBuilder::getLogger) diff --git a/python/src/infer/pyGraph.cpp b/python/src/infer/pyGraph.cpp index e2d8564a9..b4fc21c62 100644 --- a/python/src/infer/pyGraph.cpp +++ b/python/src/infer/pyGraph.cpp @@ -19,6 +19,7 @@ #include "ForwardDeclarations.h" #include "utils.h" #include +#include #if ENABLE_INETWORK_SERIALIZE #include "NvInferSerialize.h" @@ -114,6 +115,14 @@ namespace tensorrt return self.addPluginV3(inputs.data(), inputs.size(), shapeInputs.data(), shapeInputs.size(), plugin); }; + static const auto add_plugin = [] (INetworkDefinition& self, std::tuple const&, std::vector const&, IPluginV3&> tupleInput) + { + std::vector const& inputs = std::get<0>(tupleInput); + std::vector const& shapeInputs = std::get<1>(tupleInput); + IPluginV3& plugin = std::get<2>(tupleInput); + return self.addPluginV3(inputs.data(), inputs.size(), shapeInputs.data(), shapeInputs.size(), plugin); + }; + static const auto add_convolution_nd = [](INetworkDefinition& self, ITensor& input, int32_t numOutputMaps, Dims kernelSize, Weights kernel, Weights* bias) { return self.addConvolutionNd(input, numOutputMaps, kernelSize, kernel, optionalWeights(bias)); @@ -244,6 +253,8 @@ namespace tensorrt else return py::cast(self.getBeta()); }; + + } /* lambdas */ void bindGraph(py::module& m) @@ -818,6 +829,7 @@ namespace tensorrt .def_property("compute_precision", &INormalizationLayer::getComputePrecision, &INormalizationLayer::setComputePrecision) ; + // Weights must be kept alive for the duration of the network. py::keep_alive is critical here! // Additionally, we use reference_internal so that pybind11 does not free layers when they go out of scope. py::class_(m, "INetworkDefinition", INetworkDefinitionDoc::descr, py::module_local()) @@ -893,6 +905,8 @@ namespace tensorrt INetworkDefinitionDoc::add_plugin_v2, py::return_value_policy::reference_internal) .def("add_plugin_v3", lambdas::add_plugin_v3, "inputs"_a, "shape_inputs"_a, "plugin"_a, INetworkDefinitionDoc::add_plugin_v3, py::return_value_policy::reference_internal) + .def("add_plugin", lambdas::add_plugin, "tuple"_a, + INetworkDefinitionDoc::add_plugin, py::return_value_policy::reference_internal) .def("add_parametric_relu", &INetworkDefinition::addParametricReLU, "input"_a, "slopes"_a, INetworkDefinitionDoc::add_parametric_relu, py::return_value_policy::reference_internal) .def("add_resize", &INetworkDefinition::addResize, "input"_a, INetworkDefinitionDoc::add_resize, diff --git a/python/src/infer/pyPlugin.cpp b/python/src/infer/pyPlugin.cpp index 0396fee7b..72b5d9cb0 100644 --- a/python/src/infer/pyPlugin.cpp +++ b/python/src/infer/pyPlugin.cpp @@ -17,6 +17,7 @@ // This file contains all bindings related to plugins. #include "ForwardDeclarations.h" +#include "impl/plugin.h" #include "infer/pyPluginDoc.h" #include "utils.h" #include @@ -782,7 +783,18 @@ class PyIPluginV3Impl : public IPluginV3 { if (type == PluginCapabilityType::kCORE) { - return pyResult.cast(); + try + { + return pyResult.cast(); + } + catch (py::cast_error const& e) + { + try + { + return pyResult.cast(); + } + PLUGIN_API_CATCH_CAST("get_capability_interface", " a valid core capability interface") + } } if (type == PluginCapabilityType::kBUILD) { @@ -796,12 +808,30 @@ class PyIPluginV3Impl : public IPluginV3 { return pyResult.cast(); } - PLUGIN_API_CATCH_CAST("get_capability_interface", " a valid build capability interface") + catch (py::cast_error const& e) + { + try + { + return pyResult.cast(); + } + PLUGIN_API_CATCH_CAST("get_capability_interface", " a valid build capability interface") + } } } if (type == PluginCapabilityType::kRUNTIME) { - return pyResult.cast(); + try + { + return pyResult.cast(); + } + catch (py::cast_error const& e) + { + try + { + return pyResult.cast(); + } + PLUGIN_API_CATCH_CAST("get_capability_interface", " a valid runtime capability interface") + } } } PLUGIN_API_CATCH_CAST("get_capability_interface", "nvinfer1::IPluginCapability") @@ -1339,69 +1369,674 @@ class PyIPluginV3OneBuildBaseImpl : public T mIsMetadataStringInitialized = true; } -private: - int32_t mNbOutputs{}; - int32_t mFormatCombinationLimit{}; - std::string mTimingCachedId{}; - std::string mMetadataString{}; - std::vector mTactics; +private: + int32_t mNbOutputs{}; + int32_t mFormatCombinationLimit{}; + std::string mTimingCachedId{}; + std::string mMetadataString{}; + std::vector mTactics; + + bool mIsNbOutputsInitialized{false}; + bool mIsTimingCachedIdInitialized{false}; + bool mIsFormatCombinationLimitInitialized{false}; + bool mIsMetadataStringInitialized{false}; + bool mIsTacticsInitialized{false}; +}; + +class PyIPluginV3OneBuildImpl : public PyIPluginV3OneBuildBaseImpl +{ +public: + PyIPluginV3OneBuildImpl() + : PyIPluginV3OneBuildBaseImpl(this) + { + } + PyIPluginV3OneBuildImpl(IPluginV3OneBuild const& a) + : PyIPluginV3OneBuildBaseImpl(this){}; +}; + +class PyIPluginV3OneBuildV2Impl : public PyIPluginV3OneBuildBaseImpl +{ +public: + PyIPluginV3OneBuildV2Impl() + : PyIPluginV3OneBuildBaseImpl(this) + { + } + PyIPluginV3OneBuildV2Impl(IPluginV3OneBuildV2 const& a) + : PyIPluginV3OneBuildBaseImpl(this){}; + + int32_t getAliasedInput(int32_t outputIndex) noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + + py::function pyGetAliasedInput + = py::get_override(static_cast(this), "get_aliased_input"); + + if (!pyGetAliasedInput) + { + // if no implementation is provided for get_aliased_input(), default to no aliasing + return -1; + } + + py::object pyResult = pyGetAliasedInput(outputIndex); + + try + { + auto result = pyResult.cast(); + return result; + } + PLUGIN_API_CATCH_CAST("get_aliased_input", "int32_t") + return -1; + } + PLUGIN_API_CATCH("get_aliased_input") + return -1; + } +}; + +class PyIPluginV3QuickCoreImpl : public IPluginV3QuickCore +{ +public: + using IPluginV3QuickCore::IPluginV3QuickCore; + PyIPluginV3QuickCoreImpl() = default; + PyIPluginV3QuickCoreImpl(const IPluginV3QuickCore& a){}; + + APILanguage getAPILanguage() const noexcept final + { + return APILanguage::kPYTHON; + } + + char const* getPluginName() const noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + if (!mPluginName.has_value()) + { + utils::throwPyError(PyExc_AttributeError, "plugin_name not initialized"); + } + return mPluginName.value().c_str(); + } + PLUGIN_API_CATCH("plugin_name") + return nullptr; + } + + char const* getPluginVersion() const noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + if (!mPluginVersion.has_value()) + { + utils::throwPyError(PyExc_AttributeError, "plugin_version not initialized"); + } + return mPluginVersion.value().c_str(); + } + PLUGIN_API_CATCH("plugin_version") + return nullptr; + } + + char const* getPluginNamespace() const noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + // getPluginNamespace() is not passed through to the Python side + if (!mPluginNamespace.has_value()) + { + utils::throwPyError(PyExc_AttributeError, "plugin_namespace not initialized"); + } + return mPluginNamespace.value().c_str(); + } + PLUGIN_API_CATCH("plugin_namespace") + return nullptr; + } + + void setPluginName(std::string pluginName) + { + mPluginName = std::move(pluginName); + } + + void setPluginNamespace(std::string pluginNamespace) + { + mPluginNamespace = std::move(pluginNamespace); + } + + void setPluginVersion(std::string pluginVersion) + { + mPluginVersion = std::move(pluginVersion); + } + +private: + std::optional mPluginNamespace; + std::optional mPluginName; + std::optional mPluginVersion; +}; + +class PyIPluginV3QuickBuildImpl : public IPluginV3QuickBuild +{ +public: + using IPluginV3QuickBuild::IPluginV3QuickBuild; + PyIPluginV3QuickBuildImpl() = default; + PyIPluginV3QuickBuildImpl(const IPluginV3QuickBuild& a){}; + + APILanguage getAPILanguage() const noexcept final + { + return APILanguage::kPYTHON; + } + + int32_t getNbOutputs() const noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + if (!mNbOutputs.has_value()) + { + utils::throwPyError(PyExc_AttributeError, "num_outputs not initialized"); + } + return mNbOutputs.value(); + } + PLUGIN_API_CATCH("num_outputs") + return -1; + } + + int32_t getNbTactics() noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + + try + { + py::function pyGetValidTactics + = py::get_override(static_cast(this), "get_valid_tactics"); + + if (!pyGetValidTactics) + { + // if no implementation is provided for get_valid_tactics(), communicate that no custom tactics are + // used by the plugin + return 0; + } + + py::object pyResult = pyGetValidTactics(); + mTactics = pyResult.cast>(); + return static_cast(mTactics.value().size()); + } + PLUGIN_API_CATCH_CAST("get_valid_tactics", "std::vector") + catch (py::error_already_set& e) + { + std::cerr << "[ERROR] Exception thrown from get_valid_tactics() " << e.what() << std::endl; + } + } + PLUGIN_API_CATCH("tactics") + return -1; + } + + int32_t getValidTactics(int32_t* tactics, int32_t nbTactics) noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + + try + { + // getValidTactics() must immediately follow getNbTactics() + // because it is impossible to call getValidTactics() without knowing the + // correct number of tactics. So check that mTactics.has_value() is true. + // Otherwise, something has gone wrong. + if (mTactics.has_value()) + { + if (nbTactics != static_cast(mTactics.value().size())) + { + utils::throwPyError( + PyExc_RuntimeError, "number of tactics does not match cached number of tactics"); + } + std::copy(mTactics.value().begin(), mTactics.value().end(), tactics); + // Reset to catch any subsequent violations + mTactics.reset(); + return 0; + } + else + { + utils::throwPyError( + PyExc_RuntimeError, "Internal error. getValidTactics() called before getNbTactics()."); + } + return -1; + } + PLUGIN_API_CATCH_CAST("get_valid_tactics", "std::vector") + catch (py::error_already_set& e) + { + std::cerr << "[ERROR] Exception thrown from get_valid_tactics() " << e.what() << std::endl; + } + } + PLUGIN_API_CATCH("tactics") + return -1; + } + + int32_t configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, + int32_t nbOutputs) noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + + py::function pyConfigurePlugin + = utils::getOverride(static_cast(this), "configure_plugin"); + + if (!pyConfigurePlugin) + { + utils::throwPyError(PyExc_RuntimeError, "no implementation provided for configure_plugin()"); + } + + std::vector inVector; + std::vector outVector; + std::copy_n(in, nbInputs, std::back_inserter(inVector)); + std::copy_n(out, nbOutputs, std::back_inserter(outVector)); + + try + { + pyConfigurePlugin(inVector, outVector); + return 0; + } + catch (py::error_already_set& e) + { + std::cerr << "[ERROR] Exception thrown from configure_plugin() " << e.what() << std::endl; + } + } + PLUGIN_API_CATCH("configure_plugin") + return -1; + } + + int32_t getNbSupportedFormatCombinations( + DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + + py::function pySupportsFormatCombination + = utils::getOverride(static_cast(this), "get_supported_format_combinations"); + if (!pySupportsFormatCombination) + { + utils::throwPyError( + PyExc_RuntimeError, "no implementation provided for get_supported_format_combinations()"); + } + + std::vector inOutVector; + std::copy_n(inOut, nbInputs + nbOutputs, std::back_inserter(inOutVector)); + + py::object pyResult = pySupportsFormatCombination(inOutVector, nbInputs); + try + { + mSupportedFormatCombinations = pyResult.cast>(); + if (static_cast(mSupportedFormatCombinations.value().size()) % (nbInputs + nbOutputs) != 0) + { + utils::throwPyError( + PyExc_ValueError, "Number of supported format combinations not a multiple of number of IO."); + } + return static_cast(mSupportedFormatCombinations.value().size()) / (nbInputs + nbOutputs); + } + PLUGIN_API_CATCH_CAST("get_nb_supported_format_combinations", "int32_t") + catch (py::error_already_set& e) + { + std::cerr << "[ERROR] Exception thrown from get_supported_format_combinations() " << e.what() + << std::endl; + } + return -1; + } + PLUGIN_API_CATCH("get_nb_supported_format_combinations") + return -1; + } + + int32_t getSupportedFormatCombinations(DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs, + PluginTensorDesc* supportedCombinations, int32_t nbFormatCombinations) noexcept override + { + py::gil_scoped_acquire gil{}; + + py::function pySupportsFormatCombination + = utils::getOverride(static_cast(this), "get_supported_format_combinations"); + if (!pySupportsFormatCombination) + { + utils::throwPyError( + PyExc_RuntimeError, "no implementation provided for get_supported_format_combinations()"); + } + + std::vector inOutVector; + std::copy_n(inOut, nbInputs + nbOutputs, std::back_inserter(inOutVector)); + + py::object pyResult = pySupportsFormatCombination(inOutVector, nbInputs); + + try + { + // getSupportedFormatCombinations() must immediately follow getNbSupportedFormatCombinations() + // because it is impossible to call getSupportedFormatCombinations() without knowing the + // correct number of tactics. So check that mSupportedFormatCombinations.has_value(). + // Otherwise, something has gone wrong. + if (mSupportedFormatCombinations.has_value()) + { + std::copy(mSupportedFormatCombinations.value().begin(), mSupportedFormatCombinations.value().end(), + supportedCombinations); + // Reset to catch any subsequent violations + mSupportedFormatCombinations.reset(); + return 0; + } + else + { + utils::throwPyError(PyExc_RuntimeError, + "Internal error. getSupportedFormatCombinations() called before " + "getNbSupportedFormatCombinations()."); + } + return -1; + } + PLUGIN_API_CATCH_CAST("get_supported_format_combinations", "std::vector") + catch (py::error_already_set& e) + { + std::cerr << "[ERROR] Exception thrown from get_supported_format_combinations() " << e.what() << std::endl; + } + return -1; + } + + int32_t getOutputDataTypes(DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, + int32_t const* inputRanks, int32_t nbInputs) const noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + + py::function pyGetOutputDataTypes + = utils::getOverride(static_cast(this), "get_output_data_types"); + if (!pyGetOutputDataTypes) + { + utils::throwPyError(PyExc_RuntimeError, "no implementation provided for get_output_data_types()"); + } + + std::vector inVector; + std::vector ranksVector; + std::copy_n(inputTypes, nbInputs, std::back_inserter(inVector)); + std::copy_n(inputRanks, nbInputs, std::back_inserter(ranksVector)); + + try + { + py::object pyResult = pyGetOutputDataTypes(inVector, ranksVector); + auto result = pyResult.cast>(); + + if (static_cast(result.size()) != nbOutputs) + { + utils::throwPyError(PyExc_RuntimeError, + "get_output_data_types() returned a list with a different length than num_outputs"); + } + + std::copy(result.begin(), result.end(), outputTypes); + return 0; + } + PLUGIN_API_CATCH_CAST("get_output_data_types", "std::vector") + catch (py::error_already_set& e) + { + std::cerr << "[ERROR] Exception thrown from get_output_data_types() " << e.what() << std::endl; + } + } + PLUGIN_API_CATCH("get_output_data_types") + return -1; + } + + int32_t getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs, + int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + + py::function pyGetOutputShapes + = utils::getOverride(static_cast(this), "get_output_shapes"); + if (!pyGetOutputShapes) + { + utils::throwPyError(PyExc_RuntimeError, "no implementation provided for get_output_shapes()"); + } + + std::vector inVector; + std::vector shapeInVector; + std::copy_n(inputs, nbInputs, std::back_inserter(inVector)); + std::copy_n(shapeInputs, nbShapeInputs, std::back_inserter(shapeInVector)); + + py::object pyResult = pyGetOutputShapes(inVector, shapeInVector, &exprBuilder); + + try + { + auto result = pyResult.cast>(); + if (static_cast(result.size()) != nbOutputs) + { + utils::throwPyError(PyExc_RuntimeError, + "get_output_shapes() returned a list with a different length than num_outputs"); + } + std::copy(result.begin(), result.end(), outputs); + return 0; + } + PLUGIN_API_CATCH_CAST("get_output_shapes", "std::vector") + catch (py::error_already_set& e) + { + std::cerr << "[ERROR] Exception thrown from get_output_shapes() " << e.what() << std::endl; + } + return -1; + } + PLUGIN_API_CATCH("get_output_shapes") + return -1; + } + + int32_t getAliasedInput(int32_t outputIndex) noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + + py::function pyGetAliasedInput + = py::get_override(static_cast(this), "get_aliased_input"); + + if (!pyGetAliasedInput) + { + // if no implementation is provided for get_aliased_input(), default to no aliasing + return -1; + } + + py::object pyResult = pyGetAliasedInput(outputIndex); + + try + { + auto result = pyResult.cast(); + return result; + } + PLUGIN_API_CATCH_CAST("get_aliased_input", "int32_t") + return -1; + } + PLUGIN_API_CATCH("get_aliased_input") + return -1; + } + + char const* getTimingCacheID() noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + if (!mTimingCachedId.has_value()) + { + return nullptr; + } + return mTimingCachedId.value().c_str(); + } + PLUGIN_API_CATCH("timing_cache_id") + return nullptr; + } + + char const* getMetadataString() noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + if (!mMetadataString.has_value()) + { + return nullptr; + } + return mMetadataString.value().c_str(); + } + PLUGIN_API_CATCH("metadata_string") + return nullptr; + } + + void setNbOutputs(int32_t nbOutputs) + { + mNbOutputs = nbOutputs; + } + + void setTimingCachedId(std::string timingCachedId) + { + mTimingCachedId = std::move(timingCachedId); + } + + void setMetadataString(std::string metadataString) + { + mMetadataString = std::move(metadataString); + } + +private: + std::optional mNbOutputs{}; + std::optional mTimingCachedId{}; + std::optional mMetadataString{}; + std::optional> mTactics; + std::optional> mSupportedFormatCombinations{}; +}; + +class PyIPluginV3QuickRuntimeImpl : public IPluginV3QuickRuntime +{ +public: + using IPluginV3QuickRuntime::IPluginV3QuickRuntime; + PyIPluginV3QuickRuntimeImpl() = default; + PyIPluginV3QuickRuntimeImpl(const IPluginV3QuickRuntime& a){}; + + APILanguage getAPILanguage() const noexcept final + { + return APILanguage::kPYTHON; + } + + int32_t enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, void const* const* inputs, + void* const* outputs, Dims const* inputStrides, Dims const* outputStrides, int32_t nbInputs, int32_t nbOutputs, + cudaStream_t stream) noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + + py::function pyEnqueue = utils::getOverride(static_cast(this), "enqueue"); + if (!pyEnqueue) + { + utils::throwPyError(PyExc_RuntimeError, "no implementation provided for enqueue()"); + } + + std::vector inVector; + std::vector outVector; + std::copy_n(inputDesc, nbInputs, std::back_inserter(inVector)); + std::copy_n(outputDesc, nbOutputs, std::back_inserter(outVector)); + + std::vector inPtrs; + for (int32_t idx = 0; idx < nbInputs; ++idx) + { + inPtrs.push_back(reinterpret_cast(inputs[idx])); + } + std::vector outPtrs; + for (int32_t idx = 0; idx < nbOutputs; ++idx) + { + outPtrs.push_back(reinterpret_cast(outputs[idx])); + } + + intptr_t cudaStreamPtr = reinterpret_cast(stream); - bool mIsNbOutputsInitialized{false}; - bool mIsTimingCachedIdInitialized{false}; - bool mIsFormatCombinationLimitInitialized{false}; - bool mIsMetadataStringInitialized{false}; - bool mIsTacticsInitialized{false}; -}; + std::vector inStrides; + std::vector outStrides; + std::copy_n(inputStrides, nbInputs, std::back_inserter(inStrides)); + std::copy_n(outputStrides, nbOutputs, std::back_inserter(outStrides)); -class PyIPluginV3OneBuildImpl : public PyIPluginV3OneBuildBaseImpl -{ -public: - PyIPluginV3OneBuildImpl() - : PyIPluginV3OneBuildBaseImpl(this) - { + try + { + pyEnqueue(inVector, outVector, inPtrs, outPtrs, inStrides, outStrides, cudaStreamPtr); + } + catch (py::error_already_set& e) + { + std::cerr << "[ERROR] Exception thrown from enqueue() " << e.what() << std::endl; + return -1; + } + return 0; + } + PLUGIN_API_CATCH("enqueue") + return -1; } - PyIPluginV3OneBuildImpl(IPluginV3OneBuild const& a) - : PyIPluginV3OneBuildBaseImpl(this){}; -}; -class PyIPluginV3OneBuildV2Impl : public PyIPluginV3OneBuildBaseImpl -{ -public: - PyIPluginV3OneBuildV2Impl() - : PyIPluginV3OneBuildBaseImpl(this) + int32_t setTactic(int32_t tactic) noexcept override { + try + { + py::gil_scoped_acquire gil{}; + + py::function pySetTactic = utils::getOverride(static_cast(this), "set_tactic"); + if (!pySetTactic) + { + utils::throwPyError(PyExc_RuntimeError, "no implementation provided for set_tactic()"); + } + + try + { + pySetTactic(tactic); + } + catch (py::error_already_set& e) + { + std::cerr << "[ERROR] Exception thrown from set_tactic() " << e.what() << std::endl; + return -1; + } + return 0; + } + PLUGIN_API_CATCH("set_tactic") + return -1; } - PyIPluginV3OneBuildV2Impl(IPluginV3OneBuildV2 const& a) - : PyIPluginV3OneBuildBaseImpl(this){}; - int32_t getAliasedInput(int32_t outputIndex) noexcept override + PluginFieldCollection const* getFieldsToSerialize() noexcept override { try { py::gil_scoped_acquire gil{}; - py::function pyGetAliasedInput - = py::get_override(static_cast(this), "get_aliased_input"); - - if (!pyGetAliasedInput) + py::function pyGetFieldsToSerialize + = utils::getOverride(static_cast(this), "get_fields_to_serialize"); + if (!pyGetFieldsToSerialize) { - // if no implementation is provided for get_aliased_input(), default to no aliasing - return -1; + utils::throwPyError(PyExc_RuntimeError, "no implementation provided for get_fields_to_serialize()"); } - py::object pyResult = pyGetAliasedInput(outputIndex); + py::object result = pyGetFieldsToSerialize(); try { - auto result = pyResult.cast(); - return result; + mFC = result.cast(); + return &mFC; } - PLUGIN_API_CATCH_CAST("get_aliased_input", "int32_t") - return 0U; + PLUGIN_API_CATCH_CAST("get_fields_to_serialize", "nvinfer1::PluginFieldCollection") + return nullptr; } - PLUGIN_API_CATCH("get_aliased_input") - return -1; + PLUGIN_API_CATCH("get_fields_to_serialize") + return nullptr; + } + + void setPluginType(std::string pluginType) + { + mPluginType = std::move(pluginType); + } + + void setPluginVersion(std::string pluginVersion) + { + mPluginVersion = std::move(pluginVersion); } + +private: + PluginFieldCollection mFC; + std::optional mNamespace; + std::optional mPluginType; + std::optional mPluginVersion; }; class PyIPluginV3OneRuntimeImpl : public IPluginV3OneRuntime @@ -1583,7 +2218,8 @@ class PyIPluginV3OneRuntimeImpl : public IPluginV3OneRuntime try { - return result.cast(); + mFC = result.cast(); + return &mFC; } PLUGIN_API_CATCH_CAST("get_fields_to_serialize", "nvinfer1::PluginFieldCollection") return nullptr; @@ -1610,6 +2246,7 @@ class PyIPluginV3OneRuntimeImpl : public IPluginV3OneRuntime std::string mNamespace; std::string mPluginType; std::string mPluginVersion; + PluginFieldCollection mFC; bool mIsNbOutputsInitialized{false}; bool mIsNamespaceInitialized{false}; @@ -1835,6 +2472,132 @@ class IPluginCreatorV3OneImpl : public IPluginCreatorV3One bool mIsPluginVersionInitialized{false}; }; +class IPluginCreatorV3QuickImpl : public IPluginCreatorV3Quick +{ +public: + IPluginCreatorV3QuickImpl() = default; + + APILanguage getAPILanguage() const noexcept final + { + return APILanguage::kPYTHON; + } + + char const* getPluginName() const noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + if (!mName.has_value()) + { + utils::throwPyError(PyExc_AttributeError, "name not initialized"); + } + return mName.value().c_str(); + } + PLUGIN_API_CATCH("name") + return nullptr; + } + + char const* getPluginVersion() const noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + if (!mPluginVersion.has_value()) + { + utils::throwPyError(PyExc_AttributeError, "plugin_version not initialized"); + } + return mPluginVersion.value().c_str(); + } + PLUGIN_API_CATCH("plugin_version") + return nullptr; + } + + PluginFieldCollection const* getFieldNames() noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + if (!mFC.has_value()) + { + utils::throwPyError(PyExc_AttributeError, "field_names not initialized"); + } + return &mFC.value(); + } + PLUGIN_API_CATCH("field_names") + return nullptr; + } + + IPluginV3* createPlugin( + char const* name, char const* nspace, const PluginFieldCollection* fc, TensorRTPhase phase) noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + + py::function pyCreatePlugin + = utils::getOverride(static_cast(this), "create_plugin"); + if (!pyCreatePlugin) + { + utils::throwPyError(PyExc_RuntimeError, "no implementation provided for create_plugin()"); + } + + std::string nameString{name}; + std::string namespaceString{nspace}; + + py::handle handle = pyCreatePlugin(nameString, namespaceString, fc, phase).release(); + try + { + return handle.cast(); + } + PLUGIN_API_CATCH_CAST("create_plugin", "IPluginV3*") + return nullptr; + } + PLUGIN_API_CATCH("create_plugin") + return nullptr; + } + + char const* getPluginNamespace() const noexcept override + { + try + { + py::gil_scoped_acquire gil{}; + if (!mNamespace.has_value()) + { + utils::throwPyError(PyExc_AttributeError, "plugin_namespace not initialized"); + } + return mNamespace.value().c_str(); + } + PLUGIN_API_CATCH("plugin_namespace") + return nullptr; + } + + void setFieldNames(PluginFieldCollection fc) + { + mFC = fc; + } + + void setPluginName(std::string name) + { + mName = std::move(name); + } + + void setPluginVersion(std::string pluginVersion) + { + mPluginVersion = std::move(pluginVersion); + } + + void setPluginNamespace(std::string pluginNamespace) + { + mNamespace = std::move(pluginNamespace); + } + +private: + std::optional mFC; + std::optional mNamespace; + std::optional mName; + std::optional mPluginVersion; +}; + namespace { bool isPython(IVersionedInterface const& versionedInterface) @@ -1988,6 +2751,10 @@ static const auto get_all_creators = [](IPluginRegistry& self) -> std::vector(ptr[i++])); } + if (std::strcmp(ptr[i]->getInterfaceInfo().kind, "PLUGIN CREATOR_V3QUICK") == 0) + { + return py::cast(static_cast(ptr[i++])); + } utils::throwPyError(PyExc_RuntimeError, "Unknown plugin creator type"); return py::none{}; }); @@ -2017,7 +2784,14 @@ static const auto get_capability_interface = [](IPluginV3& self, PluginCapabilit { if (type == PluginCapabilityType::kCORE) { - return py::cast(static_cast(capability_interface)); + try + { + return py::cast(static_cast(capability_interface)); + } + catch (py::cast_error const& e) + { + return py::cast(static_cast(capability_interface)); + } } if (type == PluginCapabilityType::kBUILD) { @@ -2031,12 +2805,26 @@ static const auto get_capability_interface = [](IPluginV3& self, PluginCapabilit { return py::cast(static_cast(capability_interface)); } - PLUGIN_API_CATCH_CAST("get_capability_interface", " a valid build capability interface") + catch (py::cast_error const& e) + { + try + { + return py::cast(static_cast(capability_interface)); + } + PLUGIN_API_CATCH_CAST("get_capability_interface", " a valid build capability interface") + } } } if (type == PluginCapabilityType::kRUNTIME) { - return py::cast(static_cast(capability_interface)); + try + { + return py::cast(static_cast(capability_interface)); + } + catch (py::cast_error const& e) + { + return py::cast(static_cast(capability_interface)); + } } } PLUGIN_API_CATCH_CAST("get_capability_interface", "nvinfer1::IPluginCapability") @@ -2063,6 +2851,10 @@ static const auto get_creator = [](IPluginRegistry& self, char const* pluginType { return py::cast(static_cast(creator)); } + if (std::strcmp(creator->getInterfaceInfo().kind, "PLUGIN CREATOR_V3QUICK") == 0) + { + return py::cast(static_cast(creator)); + } utils::throwPyError(PyExc_RuntimeError, "Unknown plugin creator type"); return py::none{}; } @@ -2079,6 +2871,10 @@ static const auto creator_create_plugin_v3 return self.createPlugin(name.c_str(), fc, phase); }; +static const auto creator_create_plugin_v3_quick = + [](IPluginCreatorV3Quick& self, std::string const& name, std::string const& nspace, PluginFieldCollection const* fc, + TensorRTPhase phase) { return self.createPlugin(name.c_str(), nspace.c_str(), fc, phase); }; + static const auto deserialize_plugin = [](IPluginCreator& self, std::string const& name, py::buffer& serializedPlugin) { py::buffer_info info = serializedPlugin.request(); return self.deserializePlugin(name.c_str(), info.ptr, info.size * info.itemsize); @@ -2170,6 +2966,16 @@ static const auto IPluginV3_set_num_outputs = [](IPluginV3OneBuild& self, int32_ utils::throwPyError(PyExc_AttributeError, "Can't set attribute: num_outputs is read-only for C++ plugins"); }; +static const auto IPluginV3_quick_set_num_outputs = [](IPluginV3QuickBuild& self, int32_t numOutputs) { + if (isPython(self)) + { + auto plugin = static_cast(&self); + plugin->setNbOutputs(numOutputs); + return; + } + utils::throwPyError(PyExc_AttributeError, "Can't set attribute: num_outputs is read-only for C++ plugins"); +}; + } // namespace lambdas namespace helpers @@ -2635,6 +3441,35 @@ void bindPlugin(py::module& m) // The following defs are only for documenting the API for Python-based plugins .def("get_aliased_input", &pluginDoc::getAliasedInput, IPluginV3Doc::get_valid_tactics); + py::class_>(m, "IPluginV3QuickCore", py::module_local()) + .def(py::init<>()) + .def(py::init()) + .def_property("plugin_name", &IPluginV3QuickCore::getPluginName, + py::cpp_function( + &helpers::setPluginName, py::keep_alive<1, 2>{})) + .def_property("plugin_version", &IPluginV3QuickCore::getPluginVersion, + py::cpp_function( + &helpers::setPluginVersion, py::keep_alive<1, 2>{})) + .def_property("plugin_namespace", &IPluginV3QuickCore::getPluginNamespace, + py::cpp_function( + &helpers::setPluginNamespace, py::keep_alive<1, 2>{})); + + py::class_>(m, "IPluginV3QuickBuild", py::module_local()) + .def(py::init<>()) + .def(py::init()) + .def_property("num_outputs", &IPluginV3QuickBuild::getNbOutputs, lambdas::IPluginV3_quick_set_num_outputs) + .def_property("metadata_string", &IPluginV3QuickBuild::getMetadataString, + py::cpp_function(lambdas::IPluginV3_get_metadata_string, py::keep_alive<1, 2>{})) + .def_property("timing_cache_id", &IPluginV3QuickBuild::getTimingCacheID, + py::cpp_function(lambdas::IPluginV3_get_timing_cache_id, py::keep_alive<1, 2>{})); + + py::class_>(m, "IPluginV3QuickRuntime", py::module_local()) + .def(py::init<>()) + .def(py::init()); + py::class_>( m, "IPluginV3OneRuntime", IPluginV3Doc::ipluginv3oneruntime_descr, py::module_local()) @@ -2692,6 +3527,25 @@ void bindPlugin(py::module& m) .def("create_plugin", lambdas::creator_create_plugin_v3, "name"_a, "field_collection"_a, "phase"_a, IPluginCreatorV3OneDoc::create_plugin); + py::class_( + m, "IPluginCreatorV3Quick", py::module_local()) + .def(py::init<>()) + .def_property("name", &IPluginCreatorV3Quick::getPluginName, + py::cpp_function( + &helpers::setPluginName, py::keep_alive<1, 2>{})) + .def_property("plugin_version", &IPluginCreatorV3Quick::getPluginVersion, + py::cpp_function( + &helpers::setPluginVersion, py::keep_alive<1, 2>{})) + .def_property("field_names", &helpers::getFieldNames, + py::cpp_function(&helpers::setPluginCreatorFieldNames, + py::keep_alive<1, 2>{}), + py::return_value_policy::reference_internal) + .def_property("plugin_namespace", &IPluginCreatorV3Quick::getPluginNamespace, + py::cpp_function( + &helpers::setPluginNamespace, py::keep_alive<1, 2>{})) + .def("create_plugin", lambdas::creator_create_plugin_v3_quick, "name"_a, "namespace"_a, "field_collection"_a, + "phase"_a); + py::class_>( m, "IPluginResourceContext", IPluginResourceContextDoc::descr, py::module_local()) // return_value_policy::reference_internal is default for the following @@ -2742,8 +3596,18 @@ void bindPlugin(py::module& m) .value("V1", PluginCreatorVersion::kV1) .value("V1_PYTHON", PluginCreatorVersion::kV1_PYTHON); - m.def("get_plugin_registry", &getPluginRegistry, py::return_value_policy::reference, - FreeFunctionsDoc::get_plugin_registry); + m.add_object("_plugin_registry", py::none()); + + m.def( + "get_plugin_registry", + [m]() { + if (m.attr("_plugin_registry").is_none()) + { + m.attr("_plugin_registry") = py::cast(getPluginRegistry()); + } + return m.attr("_plugin_registry"); + }, + py::return_value_policy::reference, FreeFunctionsDoc::get_plugin_registry); py::enum_( m, "PluginCapabilityType", py::arithmetic{}, PluginCapabilityTypeDoc::descr, py::module_local()) diff --git a/samples/common/sampleDevice.cpp b/samples/common/sampleDevice.cpp index 7964aeb5d..e9ad78dd2 100644 --- a/samples/common/sampleDevice.cpp +++ b/samples/common/sampleDevice.cpp @@ -31,6 +31,7 @@ void cudaCheck(cudaError_t ret, std::ostream& err) } } +#if !TRT_WINML // Construct GPU UUID string in the same format as nvidia-smi does. std::string getUuidString(cudaUUID_t uuid) { @@ -54,7 +55,6 @@ std::string getUuidString(cudaUUID_t uuid) void setCudaDevice(int32_t device, std::ostream& os) { -#if !TRT_WINML os << "=== Device Information ===" << std::endl; // Get the number of visible GPUs. @@ -113,7 +113,6 @@ void setCudaDevice(int32_t device, std::ostream& os) os << "Note: The application clock rates do not reflect the actual clock rates that the GPU is " << "currently running at." << std::endl; // clang-format on -#endif } int32_t getCudaDriverVersion() @@ -129,5 +128,6 @@ int32_t getCudaRuntimeVersion() cudaCheck(cudaRuntimeGetVersion(&version)); return version; } +#endif } // namespace sample diff --git a/samples/common/sampleDevice.h b/samples/common/sampleDevice.h index 986dccb41..ef6a00a25 100644 --- a/samples/common/sampleDevice.h +++ b/samples/common/sampleDevice.h @@ -532,6 +532,7 @@ class OutputAllocator : public nvinfer1::IOutputAllocator nvinfer1::Dims mFinalDims; }; +#if !TRT_WINML //! Set the GPU to run the inference on. void setCudaDevice(int32_t device, std::ostream& os); @@ -541,6 +542,8 @@ int32_t getCudaDriverVersion(); //! Get the CUDA version of the current CUDA runtime. int32_t getCudaRuntimeVersion(); +#endif + } // namespace sample #endif // TRT_SAMPLE_DEVICE_H diff --git a/samples/common/sampleEngines.cpp b/samples/common/sampleEngines.cpp index 51d1329cd..5dddceeb0 100644 --- a/samples/common/sampleEngines.cpp +++ b/samples/common/sampleEngines.cpp @@ -909,6 +909,11 @@ bool setupNetworkAndConfig(BuildOptions const& build, SystemOptions const& sys, } } + if (build.enableMonitorMemory) + { + config.setFlag(BuilderFlag::kMONITOR_MEMORY); + } + config.setProfilingVerbosity(build.profilingVerbosity); config.setAvgTimingIterations(build.avgTiming); @@ -1305,6 +1310,7 @@ bool loadStreamingEngineToBuildEnv(std::string const& filepath, BuildEnvironment return true; } + bool loadEngineToBuildEnv(std::string const& filepath, BuildEnvironment& env, std::ostream& err) { auto const tBegin = std::chrono::high_resolution_clock::now(); @@ -1644,9 +1650,9 @@ namespace void* initSafeRuntime() { void* handle{nullptr}; - // Currently libsafe_executor_debug.so for samplesCommon::isDebug() is not ready. + // Currently libnvinfer_safe_debug.so for samplesCommon::isDebug() is not ready. #if !defined(_WIN32) - std::string const dllName{"libsafe_executor.so"}; + std::string const dllName{"libnvinfer_safe.so"}; #if SANITIZER_BUILD handle = dlopen(dllName.c_str(), RTLD_LAZY | RTLD_NODELETE); #else diff --git a/samples/common/sampleEngines.h b/samples/common/sampleEngines.h index ec02e9097..d1d88319e 100644 --- a/samples/common/sampleEngines.h +++ b/samples/common/sampleEngines.h @@ -159,6 +159,7 @@ class LazilyDeserializedEngine return *mFileReader; } + //! //! \brief Get if safe mode is enabled. //! diff --git a/samples/common/sampleInference.cpp b/samples/common/sampleInference.cpp index ca0098d42..77a99c1d1 100644 --- a/samples/common/sampleInference.cpp +++ b/samples/common/sampleInference.cpp @@ -256,6 +256,11 @@ bool setUpInference(InferenceEnvironment& iEnv, InferenceOptions const& inferenc // Release serialized blob to save memory space. iEnv.engine.releaseBlob(); +#if TRT_WINML + // Start JIT Compilation time after engine deserialization + auto jitCompileBegin = std::chrono::high_resolution_clock::now(); +#endif + // Setup weight streaming if enabled if (engine->getStreamableWeightsSize() > 0) { @@ -502,8 +507,17 @@ bool setUpInference(InferenceEnvironment& iEnv, InferenceOptions const& inferenc } auto const* context = iEnv.contexts.front().get(); - return FillStdBindings( + bool fillBindingsSuccess = FillStdBindings( engine, context, inference.inputs, iEnv.bindings, 1, endBindingIndex, inference.optProfileIndex)(); + +#if TRT_WINML + // Stop JIT Compile Time when setup for inference is complete + auto jitCompileEnd = std::chrono::high_resolution_clock::now(); + sample::gLogInfo << "JIT Compilation in " << std::chrono::duration(jitCompileEnd - jitCompileBegin).count() + << " sec." << std::endl; +#endif + + return fillBindingsSuccess; } TaskInferenceEnvironment::TaskInferenceEnvironment( @@ -1169,18 +1183,22 @@ bool timeDeserialize(InferenceEnvironment& iEnv, SystemOptions const& sys) bool deserializeOK{false}; engine.reset(nullptr); auto startClock = std::chrono::high_resolution_clock::now(); + SMP_RETVAL_IF_FALSE(!iEnv.safe, "Safe inference is not supported!", false, sample::gLogError); - auto& reader = iEnv.engine.getFileReader(); - reader.reset(); - ASSERT(reader.isOpen()); #if !TRT_WINML for (auto const& pluginPath : sys.dynamicPlugins) { rt->getPluginRegistry().loadLibrary(pluginPath.c_str()); } #endif + + auto& reader = iEnv.engine.getFileReader(); + ASSERT(reader.isOpen()); + reader.reset(); engine.reset(rt->deserializeCudaEngine(reader)); + deserializeOK = (engine != nullptr); + deserializeOK = (engine != nullptr); auto endClock = std::chrono::high_resolution_clock::now(); // return NAN if deserialization failed. diff --git a/samples/common/sampleOptions.cpp b/samples/common/sampleOptions.cpp index bdb1b21c3..283091f1b 100644 --- a/samples/common/sampleOptions.cpp +++ b/samples/common/sampleOptions.cpp @@ -1222,6 +1222,7 @@ void BuildOptions::parse(Arguments& arguments) getAndDelOption(arguments, "--excludeLeanRuntime", excludeLeanRuntime); getAndDelOption(arguments, "--noCompilationCache", disableCompilationCache); + getAndDelOption(arguments, "--monitorMemory", enableMonitorMemory); getAndDelNegOption(arguments, "--noTF32", tf32); getAndDelOption(arguments, "--fp16", fp16); getAndDelOption(arguments, "--bf16", bf16); @@ -2175,6 +2176,7 @@ std::ostream& operator<<(std::ostream& os, const BuildOptions& options) "timingCacheMode: "; printTimingCache(os, options.timingCacheMode) << std::endl << "timingCacheFile: " << options.timingCacheFile << std::endl << "Enable Compilation Cache: "<< boolToEnabled(!options.disableCompilationCache) << std::endl << + "Enable Monitor Memory: "<< boolToEnabled(options.enableMonitorMemory) << std::endl << "errorOnTimingCacheMiss: " << boolToEnabled(options.errorOnTimingCacheMiss) << std::endl << "Preview Features: "; printPreviewFlags(os, options) << std::endl << "MaxAuxStreams: " << options.maxAuxStreams << std::endl << @@ -2475,6 +2477,7 @@ void BuildOptions::help(std::ostream& os) " --excludeLeanRuntime When --versionCompatible is enabled, this flag indicates that the generated engine should" "\n" " not include an embedded lean runtime. If this is set, the user must explicitly specify a" "\n" " valid lean runtime to use when loading the engine." "\n" + " --monitorMemory Enable memory monitor report for debugging usage. (default = disabled)" "\n" " --sparsity=spec Control sparsity (default = disabled). " "\n" R"( Sparsity: spec ::= "disable", "enable", "force")" "\n" " Note: Description about each of these options is as below" "\n" diff --git a/samples/common/sampleOptions.h b/samples/common/sampleOptions.h index 8ca0a655b..83e11fc46 100644 --- a/samples/common/sampleOptions.h +++ b/samples/common/sampleOptions.h @@ -238,6 +238,7 @@ class BuildOptions : public Options bool pluginInstanceNorm{false}; bool excludeLeanRuntime{false}; bool disableCompilationCache{false}; + bool enableMonitorMemory{false}; int32_t builderOptimizationLevel{defaultBuilderOptimizationLevel}; int32_t maxTactics{defaultMaxTactics}; SparsityFlag sparsity{SparsityFlag::kDISABLE}; diff --git a/samples/common/sampleUtils.h b/samples/common/sampleUtils.h index 6cd4280b9..5d1912199 100644 --- a/samples/common/sampleUtils.h +++ b/samples/common/sampleUtils.h @@ -78,9 +78,11 @@ std::vector splitToStringVec(std::string const& option, char separa bool broadcastIOFormats(std::vector const& formats, size_t nbBindings, bool isInput = true); +#if !TRT_WINML int32_t getCudaDriverVersion(); int32_t getCudaRuntimeVersion(); +#endif void sparsify(nvinfer1::INetworkDefinition& network, std::vector>& sparseWeights); void sparsify(nvinfer1::Weights const& weights, int32_t k, int32_t rs, std::vector& sparseWeights); diff --git a/samples/common/streamReader.h b/samples/common/streamReader.h index 7d4aa1c6e..8d7f78fff 100644 --- a/samples/common/streamReader.h +++ b/samples/common/streamReader.h @@ -60,7 +60,7 @@ class FileStreamReader final : public nvinfer1::IStreamReader void reset() { - assert(mFile.good()); + ASSERT(mFile.good()); mFile.seekg(0); } @@ -73,6 +73,7 @@ class FileStreamReader final : public nvinfer1::IStreamReader std::ifstream mFile; }; + } // namespace samplesCommon #endif // STREAM_READER_H diff --git a/samples/python/detectron2/requirements.txt b/samples/python/detectron2/requirements.txt index d355b4912..d9dcdc999 100644 --- a/samples/python/detectron2/requirements.txt +++ b/samples/python/detectron2/requirements.txt @@ -5,7 +5,7 @@ Pillow>=10.0.0 git+https://github.com/facebookresearch/detectron2.git git+https://github.com/NVIDIA/TensorRT#subdirectory=tools/onnx-graphsurgeon cuda-python==12.2.0; python_version <= "3.10" -cuda-python==12.5.0; python_version >= "3.11" +cuda-python==12.6.0; python_version >= "3.11" pywin32; platform_system == "Windows" pyyaml==6.0.1 requests==2.32.2 diff --git a/samples/python/downloader.py b/samples/python/downloader.py index c4240b3d7..3c1e1e046 100755 --- a/samples/python/downloader.py +++ b/samples/python/downloader.py @@ -94,7 +94,7 @@ def _downloadFile(path, url): session.mount("http://", HTTPAdapter(max_retries=retries)) session.mount("https://", HTTPAdapter(max_retries=retries)) try: - r = session.get(url, stream=True, timeout=30) + r = session.get(url, stream=True, timeout=60) if r.status_code == 200: logger.info("Connecting to %s is successful.", url) @@ -108,11 +108,18 @@ def _downloadFile(path, url): progress_bar.update(len(chunk)) fd.write(chunk) progress_bar.close() + return True else: logger.info("Failed to connect to %s with status code: %s.", url, r.status_code) - + return False + + except requests.exceptions.ConnectionError as e: + logger.debug("Connection failed after retries:", e) + except requests.exceptions.Timeout as e: + logger.debug("A timeout occurred:", e) except requests.exceptions.RequestException as e: logger.debug("Error occurred while requesting connection to %s: %s.", url, e) + return False allGood = True for f in sample_data.files: @@ -130,7 +137,7 @@ def _downloadFile(path, url): allGood = False continue _createDirIfNeeded(fpath) - _downloadFile(fpath, f.url) + assert _downloadFile(fpath, f.url) if not _checkMD5(fpath, f.checksum): logger.error("The downloaded file %s has a different checksum!", fpath) allGood = False diff --git a/samples/python/efficientdet/requirements.txt b/samples/python/efficientdet/requirements.txt index c9e040ece..dfed86b86 100644 --- a/samples/python/efficientdet/requirements.txt +++ b/samples/python/efficientdet/requirements.txt @@ -6,7 +6,7 @@ onnxruntime==1.18.1; python_version >= "3.11" tf2onnx==1.8.1; python_version <= "3.10" tf2onnx==1.16.0; python_version >= "3.11" cuda-python==12.2.0; python_version <= "3.10" -cuda-python==12.5.0; python_version >= "3.11" +cuda-python==12.6.0; python_version >= "3.11" pywin32; platform_system == "Windows" pyyaml==6.0.1 requests==2.32.2 diff --git a/samples/python/efficientnet/requirements.txt b/samples/python/efficientnet/requirements.txt index 751c0789a..83b9ea7c3 100644 --- a/samples/python/efficientnet/requirements.txt +++ b/samples/python/efficientnet/requirements.txt @@ -5,7 +5,7 @@ tensorrt>=7.1.0.0 tf2onnx==1.8.1; python_version <= "3.10" tf2onnx==1.16.0; python_version >= "3.11" cuda-python==12.2.0; python_version <= "3.10" -cuda-python==12.5.0; python_version >= "3.11" +cuda-python==12.6.0; python_version >= "3.11" pywin32; platform_system == "Windows" pyyaml==6.0.1 requests==2.32.2 diff --git a/samples/python/engine_refit_onnx_bidaf/data_processing.py b/samples/python/engine_refit_onnx_bidaf/data_processing.py index f6740bc57..7eb052adc 100644 --- a/samples/python/engine_refit_onnx_bidaf/data_processing.py +++ b/samples/python/engine_refit_onnx_bidaf/data_processing.py @@ -24,9 +24,9 @@ def preprocess(text): try: - nltk.data.find("tokenizers/punkt") + nltk.data.find("tokenizers/punkt_tab") except LookupError: - nltk.download("punkt") + nltk.download("punkt_tab") tokens = word_tokenize(text) # split into lower-case word tokens, in numpy array with shape of (seq, 1) words = np.asarray([w.lower() for w in tokens]).reshape(-1, 1) diff --git a/samples/python/engine_refit_onnx_bidaf/requirements.txt b/samples/python/engine_refit_onnx_bidaf/requirements.txt index 84469f7a7..c1c9c715a 100644 --- a/samples/python/engine_refit_onnx_bidaf/requirements.txt +++ b/samples/python/engine_refit_onnx_bidaf/requirements.txt @@ -1,8 +1,8 @@ onnx==1.16.0 -nltk==3.8.1 +nltk==3.9.1 wget==3.2 cuda-python==12.2.0; python_version <= "3.10" -cuda-python==12.5.0; python_version >= "3.11" +cuda-python==12.6.0; python_version >= "3.11" pywin32; platform_system == "Windows" pyyaml==6.0.1 requests==2.32.2 diff --git a/samples/python/introductory_parser_samples/requirements.txt b/samples/python/introductory_parser_samples/requirements.txt index 01b57c060..fc537473f 100644 --- a/samples/python/introductory_parser_samples/requirements.txt +++ b/samples/python/introductory_parser_samples/requirements.txt @@ -1,6 +1,6 @@ Pillow>=10.0.0 cuda-python==12.2.0; python_version <= "3.10" -cuda-python==12.5.0; python_version >= "3.11" +cuda-python==12.6.0; python_version >= "3.11" pywin32; platform_system == "Windows" pyyaml==6.0.1 requests==2.32.2 diff --git a/samples/python/network_api_pytorch_mnist/README.md b/samples/python/network_api_pytorch_mnist/README.md index 1f8dba76f..c5fdfb0ca 100644 --- a/samples/python/network_api_pytorch_mnist/README.md +++ b/samples/python/network_api_pytorch_mnist/README.md @@ -15,7 +15,7 @@ ## Description -This sample, `network_api_pytorch_mnist`, trains a convolutional model on the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset and runs inference with a TensorRT engine. +This sample, `network_api_pytorch_mnist`, trains a convolutional model on the [MNIST](https://ossci-datasets.s3.amazonaws.com/mnist/) dataset and runs inference with a TensorRT engine. ## How does this sample work? @@ -79,7 +79,7 @@ The following resources provide a deeper understanding about getting started wit - [MNIST model](https://github.com/pytorch/examples/tree/master/mnist) **Dataset** -- [MNIST database](http://yann.lecun.com/exdb/mnist/) +- [MNIST database](https://ossci-datasets.s3.amazonaws.com/mnist/) **Documentation** - [Introduction To NVIDIA’s TensorRT Samples](https://docs.nvidia.com/deeplearning/sdk/tensorrt-sample-support-guide/index.html#samples) diff --git a/samples/python/network_api_pytorch_mnist/requirements.txt b/samples/python/network_api_pytorch_mnist/requirements.txt index 153cbfbde..71ef1a17f 100644 --- a/samples/python/network_api_pytorch_mnist/requirements.txt +++ b/samples/python/network_api_pytorch_mnist/requirements.txt @@ -1,12 +1,8 @@ Pillow>=10.0.0 --f https://download.pytorch.org/whl/torch_stable.html -torch==2.0.0; (platform_machine=="aarch64" and sys.platform=="linux") -torch==2.2.1+cpu; ((platform_machine=="x86_64" and sys.platform=="linux") or sys.platform=="win32") --f https://download.pytorch.org/whl/torch_stable.html -torchvision==0.15.1; (platform_machine=="aarch64" and sys.platform=="linux") -torchvision==0.17.1+cpu; ((platform_machine=="x86_64" and sys.platform=="linux") or sys.platform=="win32") +torch +torchvision cuda-python==12.2.0; python_version <= "3.10" -cuda-python==12.5.0; python_version >= "3.11" +cuda-python==12.6.0; python_version >= "3.11" pywin32; platform_system == "Windows" pyyaml==6.0.1 requests==2.32.2 diff --git a/samples/python/non_zero_plugin/requirements.txt b/samples/python/non_zero_plugin/requirements.txt index 8f84ea544..595c3d8d3 100644 --- a/samples/python/non_zero_plugin/requirements.txt +++ b/samples/python/non_zero_plugin/requirements.txt @@ -1,5 +1,5 @@ cuda-python==12.2.0; python_version <= "3.10" -cuda-python==12.5.0; python_version >= "3.11" +cuda-python==12.6.0; python_version >= "3.11" cupy-cuda12x torch --extra-index-url https://pypi.ngc.nvidia.com diff --git a/samples/python/onnx_custom_plugin/requirements.txt b/samples/python/onnx_custom_plugin/requirements.txt index 4be5b282b..34c96c857 100644 --- a/samples/python/onnx_custom_plugin/requirements.txt +++ b/samples/python/onnx_custom_plugin/requirements.txt @@ -1,10 +1,10 @@ -nltk==3.8.1 +nltk==3.9.1 onnx==1.16.0 --extra-index-url https://pypi.ngc.nvidia.com onnx-graphsurgeon>=0.3.20 wget>=3.2 cuda-python==12.2.0; python_version <= "3.10" -cuda-python==12.5.0; python_version >= "3.11" +cuda-python==12.6.0; python_version >= "3.11" pywin32; platform_system == "Windows" pyyaml==6.0.1 requests==2.32.2 diff --git a/samples/python/onnx_packnet/requirements.txt b/samples/python/onnx_packnet/requirements.txt index 1cb163578..d672ad99d 100644 --- a/samples/python/onnx_packnet/requirements.txt +++ b/samples/python/onnx_packnet/requirements.txt @@ -1,12 +1,8 @@ onnx==1.16.0 --extra-index-url https://pypi.ngc.nvidia.com onnx-graphsurgeon>=0.3.20 --f https://download.pytorch.org/whl/torch_stable.html -torch==2.0.0; (platform_machine=="aarch64" and sys.platform=="linux") -torch==2.2.1+cpu; ((platform_machine=="x86_64" and sys.platform=="linux") or sys.platform=="win32") --f https://download.pytorch.org/whl/torch_stable.html -torchvision==0.15.1; (platform_machine=="aarch64" and sys.platform=="linux") -torchvision==0.17.1+cpu; ((platform_machine=="x86_64" and sys.platform=="linux") or sys.platform=="win32") +torch +torchvision pyyaml==6.0.1 requests==2.32.2 tqdm==4.66.4 diff --git a/samples/python/python_plugin/CMakeLists.txt b/samples/python/python_plugin/CMakeLists.txt index 6338ea50e..f31d0d346 100644 --- a/samples/python/python_plugin/CMakeLists.txt +++ b/samples/python/python_plugin/CMakeLists.txt @@ -39,6 +39,90 @@ if(NOT MSVC) set_ifndef(TRT_INCLUDE /usr/include/x86_64-linux-gnu) set_ifndef(CUDA_INC_DIR /usr/local/cuda/include) set_ifndef(CUDA_LIB_DIR /usr/local/cuda) + + find_program(NVCC_EXECUTABLE nvcc HINTS "${CUDA_LIB_DIR}/bin") + + # extract CUDA version + if(NVCC_EXECUTABLE) + execute_process( + COMMAND "${NVCC_EXECUTABLE}" --version + OUTPUT_VARIABLE NVCC_VERSION_OUTPUT + ERROR_VARIABLE NVCC_VERSION_ERROR + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + # Parse the version number from the output + string(REGEX MATCH "release ([0-9]+)\\.([0-9]+)" CUDA_VERSION_MATCH "${NVCC_VERSION_OUTPUT}") + if(CUDA_VERSION_MATCH) + set(CUDA_VERSION_MAJOR "${CMAKE_MATCH_1}") + set(CUDA_VERSION_MINOR "${CMAKE_MATCH_2}") + set(CUDA_VER "${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR}") + else() + message(FATAL_ERROR "Could not parse CUDA version from nvcc output.") + endif() + else() + message(FATAL_ERROR "nvcc not found in ${CUDA_INST_DIR}/bin") + endif() + + # Function to check if the current CUDA version is greater than or equal to a specified version + function(cuda_ge major minor result_var) + set(VERSION_TO_COMPARE "${major}.${minor}") + if(CUDA_VER VERSION_GREATER_EQUAL "${VERSION_TO_COMPARE}") + set(${result_var} 1 PARENT_SCOPE) + else() + set(${result_var} 0 PARENT_SCOPE) + endif() + endfunction() + + # Loop through minor versions from 0 to 9 + foreach(minor RANGE 0 9) + set(result_var "CUDA_GE_11_${minor}") + cuda_ge(11 ${minor} ${result_var}) + endforeach() + + set(SAMPLE_SMS "75") + + if(CUDA_GE_11_0) + list(APPEND SAMPLE_SMS "80") + endif() + + if(CUDA_GE_11_1) + list(APPEND SAMPLE_SMS "86") + endif() + + if(CUDA_GE_11_4) + list(APPEND SAMPLE_SMS "87") + endif() + + if(CUDA_GE_11_8) + list(APPEND SAMPLE_SMS "89" "90") + endif() + + set(NON_HFC_SMS "89" "90") + + if(NOT DEFINED GENCODES) + set(GENCODES "") + + # Add -gencode flags for each SM in SAMPLE_SMS + foreach(sm ${SAMPLE_SMS}) + list(APPEND GENCODES "-gencode=arch=compute_${sm},code=sm_${sm}") + endforeach() + + # Filter out NON_HFC_SMS from SAMPLE_SMS to get HFC_SMS + set(HFC_SMS ${SAMPLE_SMS}) + foreach(sm ${NON_HFC_SMS}) + list(REMOVE_ITEM HFC_SMS "${sm}") + endforeach() + + # Get the highest supported forward compatible SM + if(HFC_SMS) + list(SORT HFC_SMS) + list(GET HFC_SMS -1 GEN_PTX_SM) + # Add PTX generation flag + list(APPEND GENCODES "-gencode=arch=compute_${GEN_PTX_SM},code=compute_${GEN_PTX_SM}") + else() + message(WARNING "No hardware forward compatible SMs found. PTX generation skipped.") + endif() + endif() endif() message("\nThe following variables are derived from the values of the previous variables unless provided explicitly:\n") @@ -49,11 +133,13 @@ set_ifndef(NVINFER_LIB ${_NVINFER_LIB}) find_library(_CUDA_LIB cuda HINTS ${CUDA_LIB_DIR} PATH_SUFFIXES lib/stubs lib64/stubs) set_ifndef(CUDA_LIB ${_CUDA_LIB}) + # -------- BUILDING -------- add_library(circ_pad_plugin SHARED ${CMAKE_SOURCE_DIR}/circ_plugin_cpp/circ_pad_plugin.cu ) +target_compile_options(circ_pad_plugin PRIVATE ${GENCODES}) target_include_directories(circ_pad_plugin PUBLIC ${CUDA_INC_DIR} diff --git a/samples/python/python_plugin/requirements.txt b/samples/python/python_plugin/requirements.txt index 0299b6b0b..7c11ebb8b 100644 --- a/samples/python/python_plugin/requirements.txt +++ b/samples/python/python_plugin/requirements.txt @@ -1,5 +1,5 @@ cuda-python==12.2.0; python_version <= "3.10" -cuda-python==12.5.0; python_version >= "3.11" +cuda-python==12.6.0; python_version >= "3.11" cupy-cuda12x numba triton; platform_system != "Windows" diff --git a/samples/python/quickly_deployable_plugins/README.md b/samples/python/quickly_deployable_plugins/README.md new file mode 100644 index 000000000..06533248a --- /dev/null +++ b/samples/python/quickly_deployable_plugins/README.md @@ -0,0 +1,221 @@ +# Quickly Deployable TRT Python Plugins [Experimental in TensorRT 10.6] + +This is a sample to showcase quickly deployable Python-based plugin definitions (QDPs) in TRT. QDPs are able to support a large majority of use cases for adding custom operators to TRT, and will be the recommended option when it becomes a stable feature in TRT 10.7. + +## Introduction + +While the regular TRT plugin interfaces are powerful in the flexibility and tunability they provide, for the vast majority of use cases, users will benefit from the simplicity offered by the QDP workflow. + - The `tensorrt.plugin` module provides many intuitive APIs that drastically reduces the amount of boilerplate required to implement a plugin + - The concept of plugin registration, plugin creators and the plugin registry is abstracted away + - The stateless nature of QDPs all but eliminates the complications of having to comply with a predefined plugin lifecycle + +This sample contains several mini-samples that demonstrate a few common use cases. + +## Setting Up The Environment + +To build and install the bindings, follow the instructions in `$TRT_OSSPATH/python/README.md`. + +Then install the requisite packages +```bash +cd $TRT_OSSPATH/samples/python/quickly_deployable_plugins +pip3 install -r requirements.txt +``` + +# Implementing a quickly deployable Python plugin + +QDP definitions consist of a set of decorated functions that define properties and behaviors of the plugin. + - `@tensorrt.plugin.register`: Returns shape and type characteristics of output tensors, and any attributes the plugin needs to function. + - `@tensorrt.plugin.impl`: Performs the plugin computation + - (Optional) `@tensorrt.plugin.autotune`: Defines the different data types and formats (tensor layouts) supported by the plugin's IO and any tactics supported by the plugin. Defining this function allows TensorRT to "tune" the plugin during the engine build to find the most performant type/format and tactic combination on the target system. + +The specifics of these functions will become clear through the following mini-samples. + +# A Simple Plugin: Elementwise-Add + +This mini-sample contains an elementwise addition plugin, where the computation is being performed with an OpenAI Triton kernel. Let's first take a look at the `tensorrt.plugin.register` function. + +```python +import tensorrt.plugin as trtp + +@trtp.register("sample::elemwise_add_plugin") +def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> trtp.TensorDesc: + return inp0.like() +``` + +The argument "sample::elemwise_add_plugin" defines the namespace ("sample") and name ("elemwise_add_plugin") of the plugin. Input arguments to the decorated function (`plugin_desc`) annotated with `trt.plugin.TensorDesc` denote the input tensors; all others are interpreted as plugin attributes (see the [TRT API Reference](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/tensorrt.plugin/trt_plugin_register.html) for a full list of allowed attribute types). The output signature is a `trt.plugin.TensorDesc` describing the output. `inp0.like()` returns a tensor descriptor with identical shape and type characteristics to `inp0`. + +The computation function, decorated with `trt.plugin.impl`, receives `trt.plugin.Tensor`s for each input and output. In contrast to `TensorDesc`s, a `Tensor` references an underlying data buffer, directly accessible through `Tensor.data_ptr`. When working with Torch and OpenAI Triton kernels, it is easier to use `torch.as_tensor()` to zero-copy construct a `torch.Tensor` corresponding to the `trt.plugin.Tensor`. + +This sample also showcases the effect of omitting/defining a `trt.plugin.autotune` function, which must return a list of `trt.plugin.AutoTuneCombination`s. In this case, we define a single combination `AutoTuneCombination("FP32|FP16, FP32|FP16")`; this indicates that the input and output must be either both FP32 or both FP16. See the TRT API Reference for a detailed description of the grammar underlying `AutoTuneCombination`s. + +## Running the sample + +```bash +python3 qdp_runner.py add [--autotune] [-v] +``` + +`--autotune` simulates having defined a `trt.plugin.autotune` function. Enabling verbose logging (`-v`) is recommended to see the effect of autotuning. It can be observed that the `trt.plugin.impl` function is invoked several times during the engine build process when autotune is enabled. With autotuning turned off, `trt.plugin.impl` is invoked only once (when inference is run after building the engine). + +```bash +$ python3 qdp_runner.py add --autotune -v +... +Executing for inp0.dtype=DataType.FLOAT and output[0].dtype=DataType.FLOAT +Executing for inp0.dtype=DataType.FLOAT and output[0].dtype=DataType.FLOAT +Executing for inp0.dtype=DataType.FLOAT and output[0].dtype=DataType.FLOAT +Executing for inp0.dtype=DataType.FLOAT and output[0].dtype=DataType.FLOAT +Executing for inp0.dtype=DataType.HALF and output[0].dtype=DataType.HALF +Executing for inp0.dtype=DataType.HALF and output[0].dtype=DataType.HALF +Executing for inp0.dtype=DataType.HALF and output[0].dtype=DataType.HALF +Executing for inp0.dtype=DataType.HALF and output[0].dtype=DataType.HALF +[I] Finished engine building in 1.073 seconds +Executing for inp0.dtype=DataType.HALF and output[0].dtype=DataType.HALF +``` + +# Implementing in-place custom ops with I/O aliasing + +In-place computations can be accomplished with TRT plugins via aliased I/O. i.e. An input that needs to be modified in-place can be represented by an input-output pair, where the output is aliased to the input. For example, if in-place addition is needed (instead of the out-of-place addition of the above sample), that can be achieved as below: +```python +import tensorrt.plugin as trtp + +@trtp.register("sample::elemwise_add_plugin_") +def add_plugin_desc_(inp0: trtp.TensorDesc) -> trtp.TensorDesc: + return inp0.aliased() +``` + +Note the use of `trt.plugin.TensorDesc.aliased()` to produce an output `TensorDesc` that is aliased to `inp0`. + +To appreciate the effect of aliasing better, this sample adds two in-place add plugins chained together. + +## Running the sample + +Enabling verbose logging (`-v`) is recommended to see the effect of autotuning, which is always enabled. + +```bash +python3 qdp_runner.py inplace_add [--autotune] [-v] +``` + +# An op with data-dependent output shapes: Non-zero + +Non-zero is an operation where the indices of the non-zero elements of the input tensor is found -- it has data-dependent output shapes (DDS). As such, typical shape calculations cannot be done with input shapes. + +To handle DDS, the extent of each data-dependent output dimension must be expressed in terms of a *_size tensor_*, which is a scalar that communicates to TRT an upper-bound and an autotune value for that dimension, in terms of the input shapes. The TRT engine build may be optimized for the autotune value, but the extent of that dimension may stretch up to the upper-bound at runtime. + +In this sample, we consider a 2D input tensor `inp0`; the output will be an $N x 2$ tensor (a set of $N$ 2D indices), where $N$ is the number of non-zero indices. At maximum, all elements could be non-zero, and so the upper-bound could be expressed as `upper_bound = inp0.shape_expr[0] * inp0.shape_expr[1]`. Note that `trt.plugin.TensorDesc.shape_expr` returns symbolic shape expressions for that tensor. Arithmetic operations on shape expressions are supported through standard Python binary operators (see [TRT Python API reference](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/tensorrt.plugin/Shape/ShapeExpr.html) for full list of supported operations). + +On average, we can expect half of the input to be filled with zero, so a size tensor can be constructed with that as the autotune value: +```python +st = trtp.size_tensor(opt = upper_bound // 2, upper_bound = upper_bound) +``` + +Now we're ready to construct the output shape. `st.expr()` returns a shape expression for the size tensor, so a tensor descriptor for the output shape can be constructed as `trt.plugin.from_shape_expr((st.expr(), 2), dtype=trt.int32)`. TRT requires that any size tensors also be made outputs of the plugin. Putting things together, we arrive at the following: + +```python +import tensorrt.plugin as trtp + +@trtp.register("sample::non_zero_plugin") +def non_zero_plugin_reg( + inp0: trtp.TensorDesc, +) -> Tuple[trtp.TensorDesc, trtp.TensorDesc]: + upper_bound = inp0.shape_expr[0] * inp0.shape_expr[1] + st = trtp.size_tensor(upper_bound // 2, upper_bound) + return trtp.from_shape_expr((st.expr(), 2), dtype=trt.int32), st +``` + +## Running the sample + +Enabling verbose logging (`-v`) is recommended to see the effect of autotuning, which is always enabled. + +```bash +python3 qdp_runner.py non_zero [-v] +``` + +# Using multiple tactics and ONNX: Cirular padding + +This sample contains a circular padding plugin, which is useful for ops like circular convolution. + +## ONNX model with a plugin + +It is often useful to run an ONNX node with a custom op through a TRT plugin that you have written. To allow the TRT ONNX parser to correctly recognize your plugin as being mapped to an ONNX node, ensure that + - The `op` property of the node is exactly the same as your plugin name. + - The node contains a string attribute called "plugin_namespace" with the namespace of your plugin. + +In this sample, we define a plugin with the ID "sample::circ_pad_plugin", so if using ONNX Graphsurgeon, the custom op node can be constructed as follows: + +```python +import onnx_graphsurgeon as gs + +var_x = gs.Variable(name="x", shape=inp_shape, dtype=np.float32) +var_y = gs.Variable(name="y", dtype=np.float32) + +circ_pad_node = gs.Node( + name="circ_pad_plugin", + op="circ_pad_plugin", + inputs=[var_x], + outputs=[var_y], + attrs={"pads": pads, "plugin_namespace": "sample"}, +) +``` + +## Multiple tactics + +Sometimes, you may have multiple kernels (or backends) that can be used to perform the computation of the plugin -- these are typically called *_tactics_*. If it cannot be predetermined which of these tactics may perform the fastest, it is possible to let TRT time the plugin for each tactic and determine which one is fastest. + +Communicating the availability of multiple tactics can simply be done through the `trt.plugin.autotune` function. +```python +import tensorrt.plugin as trtp +from enum import IntEnum + +class Tactic(IntEnum): + TORCH = 1 + TRITON = 2 + +@trt.plugin.autotune("sample::circ_pad_plugin") +def circ_pad_plugin_autotune(inp0: trtp.TensorDesc, pads: npt.NDArray[np.int32], outputs: Tuple[trtp.TensorDesc]) -> List[trtp.AutoTuneCombination]: + c = trtp.AutoTuneCombination() + c.pos([0, 1], "FP32|FP16") + c.tactics([int(Tactic.TORCH), int(Tactic.TRITON)]) + return [c] +``` + +Note that we're using another way of constructing a `trt.plugin.AutoTuneCombination` here -- namely, through `pos(...)` to populate the type/format information and `tactics(...)` to specify the tactics. In this sample, we use an OpenAI Triton kernel and `torch.nn.functional.pad` as two methods to compute the circular padding. + +## Loading and running a TRT engine containing a plugin + +If you have a TRT engine built with a plugin, executing that engine only requires the plugin definitions for `trt.plugin.register` and `trt.plugin.impl` to be available in the module where the engine is being deserialized (note: the `trt.plugin.autotune` definition is not required to be present). + +To simulate the loading of an engine, first run this sample with the `--save_engine` flag, followed by `--artifacts_dir [dir]` with a directory in which you wish the engine to be saved. Then run the sample again with `--load engine` and `--artifacts_dir` set to the same directory. + +## Running the sample + +```bash +python3 qdp_runner.py circ_pad [--multi_tactic] [--save_engine] [--load_engine] --mode {onnx,inetdef} [--artifacts_dir ARTIFACTS_DIR] [-v] + +options: + --multi_tactic Enable multiple tactics. + --save_engine Save engine to the artifacts_dir. + --load_engine Load engine from the artifacts_dir. Ignores all other options. + --artifacts_dir ARTIFACTS_DIR + Whether to store (or retrieve) artifacts. + --mode {onnx,inetdef} Whether to use ONNX parser or INetworkDefinition APIs to construct the network. + -v, --verbose Enable verbose log output. +``` + +# Additional resources + +**`tensorrt.plugin` API reference** +- [`tensorrt.plugin` module API reference](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/tensorrt.plugin/index.html) + +**Guide to TensorRT plugins** +- [Extending TensorRT with Custom Layers](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#extending) + +# License + +For terms and conditions for use, reproduction, and distribution, see the [TensorRT Software License Agreement](https://docs.nvidia.com/deeplearning/sdk/tensorrt-sla/index.html) documentation. + +# Changelog + +October 2024: Initial release of this sample + +# Known issues + +There are no known issues in this sample diff --git a/samples/python/quickly_deployable_plugins/oait_kernels.py b/samples/python/quickly_deployable_plugins/oait_kernels.py new file mode 100644 index 000000000..fa6ecfe87 --- /dev/null +++ b/samples/python/quickly_deployable_plugins/oait_kernels.py @@ -0,0 +1,74 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import triton +import triton.language as tl + +@triton.jit +def add_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + tl.store(y_ptr + offsets, x + 1, mask=mask) + + +@triton.jit +def circ_pad( + X, + all_pads_0, + all_pads_2, + all_pads_4, + all_pads_6, + orig_dims_0, + orig_dims_1, + orig_dims_2, + orig_dims_3, + Y, + Y_shape_1, + Y_shape_2, + Y_shape_3, + X_len, + Y_len, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + i = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + mask_y = i < Y_len + + i3 = i % Y_shape_3 + i2 = (i // Y_shape_3) % Y_shape_2 + i1 = (i // Y_shape_3 // Y_shape_2) % Y_shape_1 + i0 = i // Y_shape_3 // Y_shape_2 // Y_shape_1 + + j0 = (i0 - all_pads_0 + orig_dims_0) % orig_dims_0 + j1 = (i1 - all_pads_2 + orig_dims_1) % orig_dims_1 + j2 = (i2 - all_pads_4 + orig_dims_2) % orig_dims_2 + j3 = (i3 - all_pads_6 + orig_dims_3) % orig_dims_3 + + load_idx = ( + orig_dims_3 * orig_dims_2 * orig_dims_1 * j0 + + orig_dims_3 * orig_dims_2 * j1 + + orig_dims_3 * j2 + + j3 + ) + mask_x = load_idx < X_len + + x = tl.load(X + load_idx, mask=mask_x) + + tl.store(Y + i, x, mask=mask_y) diff --git a/samples/python/quickly_deployable_plugins/qdp_defs.py b/samples/python/quickly_deployable_plugins/qdp_defs.py new file mode 100644 index 000000000..19f60a276 --- /dev/null +++ b/samples/python/quickly_deployable_plugins/qdp_defs.py @@ -0,0 +1,248 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tensorrt as trt +import torch +import numpy as np + +from typing import Tuple, List + +import tensorrt.plugin as trtp +import numpy.typing as npt + +import logging + +logging.basicConfig(level=logging.INFO) +logging.getLogger("QuicklyDeployablePlugins").setLevel(logging.INFO) + +########## Elemwise-add plugin definition ########## + + +@trtp.register("sample::elemwise_add_plugin") +def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> trtp.TensorDesc: + return inp0.like() + + +# Helper to simulate defining/omitting an autotune definition for the plugin +def register_autotune(): + # Type annotations can be omitted for autotune and impl definitions, but will be checked for consistency if added + @trtp.autotune("sample::elemwise_add_plugin") + def add_plugin_autotune( + inp0: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc] + ) -> List[trtp.AutoTuneCombination]: + return [trtp.AutoTuneCombination("FP32|FP16, FP32|FP16")] + + +@trtp.impl("sample::elemwise_add_plugin") +def add_plugin_impl( + inp0: trtp.Tensor, block_size: int, outputs: Tuple[trtp.Tensor], stream: int +) -> None: + + log = logging.getLogger("QuicklyDeployablePlugins") + log.debug( + f"Executing for inp0: dtype={inp0.dtype},format={inp0.format} and output[0]: dtype={outputs[0].dtype},format={outputs[0].format}" + ) + + n = inp0.numel() + inp0_t = torch.as_tensor(inp0, device="cuda") + out_t = torch.as_tensor(outputs[0], device="cuda") + + import triton + from oait_kernels import add_kernel + + add_kernel[(triton.cdiv(n, block_size),)](inp0_t, out_t, n, BLOCK_SIZE=block_size) + + +########## In-place elemwise-add plugin definition ########## + + +@trtp.register("sample::elemwise_add_plugin_") +def add_plugin_desc_(inp0: trtp.TensorDesc, delta: int) -> trtp.TensorDesc: + return inp0.aliased() + + +@trtp.autotune("sample::elemwise_add_plugin_") +def add_plugin_autotune_(inp0, outputs) -> List[trtp.AutoTuneCombination]: + return [ + trtp.AutoTuneCombination("FP32, FP32", "LINEAR*HWC"), + trtp.AutoTuneCombination("FP32|FP16, FP32|FP16", "LINEAR"), + ] + + +@trtp.impl("sample::elemwise_add_plugin_") +def add_plugin_impl_(inp0, delta: int, outputs, stream) -> None: + + log = logging.getLogger("QuicklyDeployablePlugins") + log.debug( + f"Executing for inp0: dtype={inp0.dtype},format={inp0.format} and output[0]: dtype={outputs[0].dtype},format={outputs[0].format}" + ) + + inp0_t = torch.as_tensor(inp0, device="cuda") + inp0_t.add_(delta) + + +########## Non-zero plugin (DDS) ########## + + +@trtp.register("sample::non_zero_plugin") +def non_zero_plugin_reg( + inp0: trtp.TensorDesc, +) -> Tuple[trtp.TensorDesc, trtp.TensorDesc]: + upper_bound = inp0.shape_expr[0] * inp0.shape_expr[1] + st = trtp.size_tensor(upper_bound // 2, upper_bound) + return trtp.from_shape_expr((st.expr(), 2), dtype=trt.int32), st + + +@trtp.autotune("sample::non_zero_plugin") +def non_zero_plugin_autotune(inp0, outputs) -> List[trtp.AutoTuneCombination]: + return [trtp.AutoTuneCombination("FP32|FP16, INT32, INT32")] + + +@trtp.impl("sample::non_zero_plugin") +def non_zero_plugin_impl(inp0, outputs, stream) -> None: + + log = logging.getLogger("QuicklyDeployablePlugins") + log.debug( + f"Executing for inp0: dtype={inp0.dtype},format={inp0.format} and output[0]: dtype={outputs[0].dtype},format={outputs[0].format}" + ) + + inp0_t = torch.as_tensor(inp0, device="cuda") + out_1 = torch.as_tensor(outputs[1], device="cuda").reshape((-1,)) + + out = torch.nonzero(inp0_t) + + out0 = torch.as_tensor(outputs[0].aliased(out.shape), device="cuda") + out0.copy_(out) + out_1.copy_(torch.Tensor([out.shape[0]])) + + +########## Circular padding plugin ######## + + +@trtp.register("sample::circ_pad_plugin") +def circ_pad_plugin_desc( + inp0: trtp.TensorDesc, pads: npt.NDArray[np.int32] +) -> trtp.TensorDesc: + ndim = inp0.ndim + out_desc = inp0.like() + + for i in range(np.size(pads) // 2): + out_desc.shape_expr[ndim - i - 1] += int(pads[i * 2] + pads[i * 2 + 1]) + + return out_desc + + +# Helper to define a multi-tactic implementation of the plugin +def enable_multi_tactic_circ_pad(): + + from enum import IntEnum + + class Tactic(IntEnum): + TORCH = 1 + TRITON = 2 + + @trtp.autotune("sample::circ_pad_plugin") + def circ_pad_plugin_autotune( + inp0: trtp.TensorDesc, + outputs: Tuple[trtp.TensorDesc], + ) -> List[trtp.AutoTuneCombination]: + c = trtp.AutoTuneCombination() + c.pos([0, 1], "FP32|FP16") + c.tactics([int(Tactic.TORCH), int(Tactic.TRITON)]) + return [c] + + @trtp.impl("sample::circ_pad_plugin") + def circ_pad_plugin_impl( + inp0: trtp.Tensor, + pads: npt.NDArray[np.int32], + outputs: Tuple[trtp.Tensor], + stream: int, + tactic: int, + ) -> None: + + log = logging.getLogger("QuicklyDeployablePlugins") + log.debug( + f"Executing for inp0: dtype={inp0.dtype},format={inp0.format} and output[0]: dtype={outputs[0].dtype},format={outputs[0].format}" + ) + + inp_t = torch.as_tensor(inp0, device="cuda") + out_t = torch.as_tensor(outputs[0], device="cuda") + + if tactic == Tactic.TORCH: + out = torch.nn.functional.pad(inp_t, pads.tolist(), mode="circular") + out_t.copy_(out) + elif tactic == Tactic.TRITON: + N = inp0.ndim + all_pads = np.zeros((N * 2,), dtype=np.int32) + out_dims = trtp.Shape(tuple(inp0.shape)) + + for i in range(np.size(pads) // 2): + out_dims[N - i - 1] += pads[i * 2] + pads[i * 2 + 1] + all_pads[N * 2 - 2 * i - 2] = pads[i * 2] + all_pads[N * 2 - 2 * i - 1] = pads[i * 2 + 1] + + all_pads = all_pads.tolist() + + block_size = 256 + num_blocks = tuple( + [int((np.prod(out_dims) + block_size - 1) // block_size)] + ) + + from oait_kernels import circ_pad + + circ_pad[num_blocks]( + inp_t, + all_pads[0], + all_pads[2], + all_pads[4], + all_pads[6], + inp0.shape[0], + inp0.shape[1], + inp0.shape[2], + inp0.shape[3], + out_t, + int(out_dims[1]), + int(out_dims[2]), + int(out_dims[3]), + inp0.numel(), + out_dims.numel(), + BLOCK_SIZE=block_size, + ) + + +# Helper to define a single tactic implementation of the plugin +def enable_single_tactic_circ_pad(): + @trtp.autotune("sample::circ_pad_plugin") + def circ_pad_plugin_autotune( + inp0: trtp.TensorDesc, + outputs: Tuple[trtp.TensorDesc], + ) -> List[trtp.AutoTuneCombination]: + + return [trtp.AutoTuneCombination("FP32|FP16, FP32|FP16")] + + @trtp.impl("sample::circ_pad_plugin") + def circ_pad_plugin_impl( + inp0: trtp.Tensor, + pads: npt.NDArray[np.int32], + outputs: Tuple[trtp.Tensor], + stream: int, + ) -> None: + inp_t = torch.as_tensor(inp0, device="cuda") + out_t = torch.as_tensor(outputs[0], device="cuda") + + out = torch.nn.functional.pad(inp_t, pads.tolist(), mode="circular") + out_t.copy_(out) diff --git a/samples/python/quickly_deployable_plugins/qdp_runner.py b/samples/python/quickly_deployable_plugins/qdp_runner.py new file mode 100644 index 000000000..f2949ef5f --- /dev/null +++ b/samples/python/quickly_deployable_plugins/qdp_runner.py @@ -0,0 +1,359 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tensorrt as trt +import torch +import numpy as np + +from polygraphy.backend.trt import ( + CreateConfig, + TrtRunner, + create_network, + engine_from_network, + network_from_onnx_path, + bytes_from_engine, + engine_from_bytes, +) + +from polygraphy.backend.common import bytes_from_path +from polygraphy import cuda + +import onnx_graphsurgeon as gs +import onnx +import os +import argparse + +import tensorrt.plugin as trtp + +import qdp_defs +import logging + +def run_add(enable_autotune=False): + + if enable_autotune: + qdp_defs.register_autotune() + + BLOCK_SIZE = 256 + + builder, network = create_network() + x = torch.randint(10, (10, 3, 32, 32), dtype=torch.float32, device="cuda") + + # Populate network + i_x = network.add_input(name="x", dtype=trt.DataType.FLOAT, shape=x.shape) + + out = network.add_plugin( + trtp.op.sample.elemwise_add_plugin(i_x, block_size=BLOCK_SIZE) + ) + out.get_output(0).name = "y" + network.mark_output(tensor=out.get_output(0)) + + builder.create_builder_config() + + engine = engine_from_network( + (builder, network), + CreateConfig(fp16=True), + ) + + with TrtRunner(engine, "trt_runner") as runner: + outputs = runner.infer( + { + "x": x, + }, + copy_outputs_to_host=False, + ) + + if torch.allclose(x + 1, outputs["y"]): + print("Inference result is correct!") + else: + print("Inference result is incorrect!") + + +def run_inplace_add(): + builder, network = create_network() + x = torch.ones((10, 3, 32, 32), dtype=torch.float32, device="cuda") + + x_clone = x.clone() + + i_x = network.add_input(name="x", dtype=trt.DataType.FLOAT, shape=x.shape) + + # Amounts to elementwise-add in the first and second plugins + deltas = (2, 4) + + out0 = network.add_plugin(trtp.op.sample.elemwise_add_plugin_(i_x, delta=deltas[0])) + out1 = network.add_plugin( + trtp.op.sample.elemwise_add_plugin_(out0.get_output(0), delta=deltas[1]) + ) + out1.get_output(0).name = "y" + network.mark_output(tensor=out1.get_output(0)) + + builder.create_builder_config() + + # Enable preview feature for aliasing plugin I/O + config = CreateConfig( + fp16=True, preview_features=[trt.PreviewFeature.ALIASED_PLUGIN_IO_10_03] + ) + + engine = engine_from_network( + (builder, network), + config, + ) + + context = engine.create_execution_context() + + stream = cuda.Stream() + + context.set_tensor_address("x", x.data_ptr()) + context.set_tensor_address("y", x.data_ptr()) + context.execute_async_v3(stream.ptr) + stream.synchronize() + + if torch.allclose(x, x_clone + sum(deltas), atol=1e-2): + print("Inference result is correct!") + else: + print("Inference result is incorrect!") + print(x[0][0][0][:10]) + print(x_clone[0][0][0][:10]) + + +def run_non_zero(): + builder, network = create_network() + inp_shape = (128, 128) + + X = np.random.normal(size=inp_shape).astype(trt.nptype(trt.DataType.FLOAT)) + + # Zero out some random indices + indices = np.random.choice( + np.prod(inp_shape), + replace=False, + size=np.random.randint(0, np.prod(inp_shape) + 1), + ) + X[np.unravel_index(indices, inp_shape)] = 0 + + # Populate network + i_x = network.add_input(name="X", dtype=trt.DataType.FLOAT, shape=inp_shape) + + out = network.add_plugin(trtp.op.sample.non_zero_plugin(i_x)) + out.get_output(0).name = "Y" + network.mark_output(tensor=out.get_output(0)) + + builder.create_builder_config() + + engine = engine_from_network( + (builder, network), + config=CreateConfig(fp16=True), + ) + + Y_ref = np.transpose(np.nonzero(X)) + + with TrtRunner(engine, "trt_runner") as runner: + outputs = runner.infer({"X": X}) + Y = outputs["Y"] + Y = Y[np.lexsort(np.fliplr(Y).T)] + + if np.allclose(Y, Y_ref, atol=1e-3): + print("Inference result is correct!") + else: + print("Inference result is incorrect!") + + +def check_artifacts_dir_exists(artifacts_dir): + if not os.path.exists(artifacts_dir): + raise ValueError(f"artifacts_dir '{artifacts_dir}' does not exist") + + +def run_circ_pad( + enable_multi_tactic=False, mode="onnx", artifacts_dir=None, save_or_load_engine=None +): + + if enable_multi_tactic: + qdp_defs.enable_multi_tactic_circ_pad() + else: + qdp_defs.enable_single_tactic_circ_pad() + + inp_shape = (10, 3, 32, 32) + x = np.random.normal(size=inp_shape).astype(trt.nptype(trt.DataType.FLOAT)) + + pads = np.array((1, 1, 1, 1), dtype=np.int32) + + if save_or_load_engine is not None and save_or_load_engine is False: + check_artifacts_dir_exists(artifacts_dir) + engine_path = os.path.join(artifacts_dir, "circ_pad.engine") + engine = engine_from_bytes(bytes_from_path(engine_path)) + else: + if mode == "inetdef": + builder, network = create_network() + i_x = network.add_input(name="x", dtype=trt.DataType.FLOAT, shape=x.shape) + out = network.add_plugin(trtp.op.sample.circ_pad_plugin(i_x, pads=pads)) + out.get_output(0).name = "y" + network.mark_output(tensor=out.get_output(0)) + + engine = engine_from_network( + (builder, network), + CreateConfig(fp16=True), + ) + elif mode == "onnx": + if artifacts_dir is None: + raise ValueError("'artifacts_dir' must be specified in onnx mode") + + check_artifacts_dir_exists(artifacts_dir) + + onnx_path = os.path.join(artifacts_dir, "circ_pad.onnx") + var_x = gs.Variable(name="x", shape=inp_shape, dtype=np.float32) + var_y = gs.Variable(name="y", dtype=np.float32) + circ_pad_node = gs.Node( + name="circ_pad_plugin", + op="circ_pad_plugin", + inputs=[var_x], + outputs=[var_y], + attrs={"pads": pads, "plugin_namespace": "sample"}, + ) + graph = gs.Graph( + nodes=[circ_pad_node], inputs=[var_x], outputs=[var_y], opset=16 + ) + onnx.save(gs.export_onnx(graph), onnx_path) + + engine = engine_from_network( + network_from_onnx_path(onnx_path), CreateConfig(fp16=True) + ) + else: + raise ValueError(f"Unknown mode {mode}") + + if save_or_load_engine is not None and save_or_load_engine is True: + check_artifacts_dir_exists(artifacts_dir) + engine_path = os.path.join(artifacts_dir, "circ_pad.engine") + with open(engine_path, "wb") as f: + f.write(bytes_from_engine(engine)) + + Y_ref = np.pad(x, [[0, 0], [0, 0], [pads[0], pads[1]], [pads[2], pads[3]]], "wrap") + + with TrtRunner(engine, "trt_runner") as runner: + outputs = runner.infer({"x": x}) + Y = outputs["y"] + + if np.allclose(Y, Y_ref, atol=1e-2): + print("Inference result is correct!") + else: + print("Inference result is incorrect!") + + +def setup_add_sample(subparsers): + subparser = subparsers.add_parser("add", help="'add' sample help") + subparser.add_argument("--autotune", action="store_true", help="Enable autotuning") + subparser.add_argument( + "-v", "--verbose", action="store_true", help="Enable more verbose log output" + ) + + +def setup_inplace_add_sample(subparsers): + subparser = subparsers.add_parser("inplace_add", help="inplace_add sample help") + subparser.add_argument( + "-v", "--verbose", action="store_true", help="Enable more verbose log output" + ) + + +def setup_non_zero_sample(subparsers): + subparser = subparsers.add_parser("non_zero", help="non_zero sample help") + subparser.add_argument( + "-v", "--verbose", action="store_true", help="Enable more verbose log output" + ) + + +def setup_circ_pad_sample(subparsers): + subparser = subparsers.add_parser("circ_pad", help="circ_pad sample help") + subparser.add_argument( + "--multi_tactic", action="store_true", help="Enable multiple tactics" + ) + subparser.add_argument( + "--save_engine", action="store_true", help="Save engine to the artifacts_dir" + ) + subparser.add_argument( + "--load_engine", + action="store_true", + help="Load engine from the artifacts_dir. Ignores all other options.", + ) + subparser.add_argument( + "--artifacts_dir", + type=str, + help="Whether to store (or retrieve) artifacts.", + ) + subparser.add_argument( + "--mode", + type=str, + choices=["onnx", "inetdef"], + help="Whether to use ONNX parser or INetworkDefinition APIs to construct the network.", + ) + subparser.add_argument( + "-v", "--verbose", action="store_true", help="Enable verbose log output" + ) + + return subparser + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser = argparse.ArgumentParser(description="Main script help") + subparsers = parser.add_subparsers(dest="sample", help="Mode help", required=True) + + setup_add_sample(subparsers) + setup_inplace_add_sample(subparsers) + circ_pad_subparser = setup_circ_pad_sample(subparsers) + setup_non_zero_sample(subparsers) + + args = parser.parse_args() + + if args.verbose: + logging.getLogger("QuicklyDeployablePlugins").setLevel(logging.DEBUG) + + if args.sample == "add": + run_add(args.autotune) + if args.sample == "inplace_add": + run_inplace_add() + if args.sample == "non_zero": + run_non_zero() + if args.sample == "circ_pad": + if args.mode == "onnx": + if args.artifacts_dir is None: + parser.error( + "circ_pad: argument --mode: When mode is 'onnx', artifacts_dir is required" + ) + + save_or_load_engine = None + + if args.load_engine is True: + if args.save_engine is True: + parser.error( + "circ_pad: save_engine and load_engine cannot be specified at the same time. First save_engine and load_engine separately." + ) + else: + if args.multi_tactic is True or args.mode is not None: + print( + "warning circ_pad: when load_engine is specified, all other options except 'artifacts_dir' is ignored." + ) + + save_or_load_engine = False + else: + if args.mode is None: + circ_pad_subparser.print_help() + parser.error( + "circ_pad: '--mode' option is required." + ) + + if args.save_engine is True: + save_or_load_engine = True + + run_circ_pad(args.multi_tactic, args.mode, args.artifacts_dir, save_or_load_engine) diff --git a/samples/python/quickly_deployable_plugins/requirements.txt b/samples/python/quickly_deployable_plugins/requirements.txt new file mode 100644 index 000000000..1b40b0c26 --- /dev/null +++ b/samples/python/quickly_deployable_plugins/requirements.txt @@ -0,0 +1,13 @@ +triton; platform_system != "Windows" +torch +--extra-index-url https://pypi.ngc.nvidia.com +polygraphy +colored +numpy==1.23.5; (platform_system != "Windows" and python_version <= "3.10") +numpy==1.26.4; (platform_system != "Windows" and python_version >= "3.11") +onnx==1.16.0; platform_system == "Windows" +--extra-index-url https://pypi.ngc.nvidia.com +onnx-graphsurgeon +pyyaml==6.0.1 +requests==2.32.2 +tqdm==4.66.4 diff --git a/samples/python/quickly_deployable_plugins/requirements.yml b/samples/python/quickly_deployable_plugins/requirements.yml new file mode 100644 index 000000000..39a435395 --- /dev/null +++ b/samples/python/quickly_deployable_plugins/requirements.yml @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +--- +args: + polygraphy: + - '--extra-index-url https://pypi.ngc.nvidia.com' +conditions: + onnx-graphsurgeon: + - onnx-graphsurgeon + onnx: + - onnx==1.16.0; platform_system == "Windows" + triton: + - triton; platform_system != "Windows" + numpy: + - 'numpy==1.23.5; (platform_system != "Windows" and python_version <= "3.10")' + - 'numpy==1.26.4; (platform_system != "Windows" and python_version >= "3.11")' +packages: + - triton + - torch + - polygraphy + - colored + - numpy + - onnx + - onnx-graphsurgeon +... diff --git a/samples/python/sample_weight_stripping/requirements.txt b/samples/python/sample_weight_stripping/requirements.txt index 01b57c060..fc537473f 100644 --- a/samples/python/sample_weight_stripping/requirements.txt +++ b/samples/python/sample_weight_stripping/requirements.txt @@ -1,6 +1,6 @@ Pillow>=10.0.0 cuda-python==12.2.0; python_version <= "3.10" -cuda-python==12.5.0; python_version >= "3.11" +cuda-python==12.6.0; python_version >= "3.11" pywin32; platform_system == "Windows" pyyaml==6.0.1 requests==2.32.2 diff --git a/samples/python/simple_progress_monitor/requirements.txt b/samples/python/simple_progress_monitor/requirements.txt index 01b57c060..fc537473f 100644 --- a/samples/python/simple_progress_monitor/requirements.txt +++ b/samples/python/simple_progress_monitor/requirements.txt @@ -1,6 +1,6 @@ Pillow>=10.0.0 cuda-python==12.2.0; python_version <= "3.10" -cuda-python==12.5.0; python_version >= "3.11" +cuda-python==12.6.0; python_version >= "3.11" pywin32; platform_system == "Windows" pyyaml==6.0.1 requests==2.32.2 diff --git a/samples/python/tensorflow_object_detection_api/requirements.txt b/samples/python/tensorflow_object_detection_api/requirements.txt index eb6d1ce33..e38c8ef9f 100644 --- a/samples/python/tensorflow_object_detection_api/requirements.txt +++ b/samples/python/tensorflow_object_detection_api/requirements.txt @@ -7,7 +7,7 @@ tf2onnx==1.15.0 pycocotools; platform_system != "Windows" pycocotools-windows; platform_system == "Windows" cuda-python==12.2.0; python_version <= "3.10" -cuda-python==12.5.0; python_version >= "3.11" +cuda-python==12.6.0; python_version >= "3.11" pywin32; platform_system == "Windows" Cython<3.0 pyyaml==6.0.1 diff --git a/samples/python/yolov3_onnx/requirements.txt b/samples/python/yolov3_onnx/requirements.txt index 9a9e9a278..32c7e45cd 100644 --- a/samples/python/yolov3_onnx/requirements.txt +++ b/samples/python/yolov3_onnx/requirements.txt @@ -1,5 +1,5 @@ cuda-python==12.2.0; python_version <= "3.10" -cuda-python==12.5.0; python_version >= "3.11" +cuda-python==12.6.0; python_version >= "3.11" pywin32; platform_system == "Windows" numpy==1.24.4; python_version <= "3.10" numpy==1.26.4; python_version >= "3.11" diff --git a/samples/trtexec/README.md b/samples/trtexec/README.md index 9b65c8e18..3d7331609 100644 --- a/samples/trtexec/README.md +++ b/samples/trtexec/README.md @@ -53,12 +53,37 @@ Compile the sample by following build instructions in [TensorRT README](https:// ### Example 1: Profiling a custom layer -You can profile a custom layer using the `IPluginRegistry` for the plugins and `trtexec`. You’ll need to first register the plugin with `IPluginRegistry`. +You can profile a custom layer, implemented as a [TensorRT plugin](https://github.com/NVIDIA/TensorRT/tree/main/plugin#tensorrt-plugins), by leveraging `trtexec`. Plugins need to be registered in the plugin registry (instance of `IPluginRegistry`) to be visible to TensorRT. `trtexec` will load the TensorRT standard plugin library (`libnvinfer_plugin.so` / `nvinfer_plugin.dll`) that provides plugin support to TensorRT. Checkout the [Non-Zero Plugins Sample](../sampleNonZeroPlugin/) for a quick sample, or the [Plugins section](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#extending) of the TensorRT Developer Guide for a more detailed walkthrough. -If you are using TensorRT shipped plugins, you should load the `libnvinfer_plugin.so` file, as these plugins are pre-registered. +Plugins can be used with `trtexec` in the following 2 ways: -If you have your own plugin, then it has to be registered explicitly. The following macro can be used to register the plugin creator `YourPluginCreator` with the `IPluginRegistry`. -`REGISTER_TENSORRT_PLUGIN(YourPluginCreator);` +
+ Using TensorRT-shipped Plugins + + +- If you are using TensorRT-shipped plugins (included in `libnvinfer_plugin.so` / `nvinfer_plugin.dll`), no extra steps are required from the user as these plugins are pre-registered with the plugin registry. +
+ +
+ Using your own Plugin + + - If you want to define your own plugin and have `trtexec` use it as part of the network, you should define your own _Plugin Shared library_ with specific entry-points recognized by TensorRT. Then, provide the shared plugin library path to `trtexec` using the `--dynamicPlugins` flag. + - More information on Plugin Shared Libraries and how to define them can be seen in the [Plugin Shared Libraries](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#plugin-serialization) section of the [TensorRT Developer Guide](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html). + + In summary, there are two methods: + 1. The `REGISTER_TENSORRT_PLUGIN` macro can be applied to the plugin creator for each plugin that needs to be statically registered. i.e. Registered at load-time of the plugin library. + 2. For dynamic registration, the plugin shared library must expose the below symbols which will be the entry points for TensorRT: + + ```cpp + extern "C" void setLoggerFinder(ILoggerFinder* finder); + extern "C" IPluginCreatorInterface* const* getCreators(int32_t& nbCreators) + ``` + In the above, `setLoggerFinder()` should accept a pointer to an `ILoggerFinder`, through which an `ILogger` instance can be retrieved for the purpose of logging inside the library code. `getCreators()` should return an array of plugin creators the library contains. Example implementations of these entry points can be found in [plugin/vc/vfcCommon.cpp](../../plugin/vc/vfcCommon.cpp) and [plugin/vc/vfcCommon.h](../../plugin/vc/vfcCommon.h). + + **Note**: Usage of `getPluginCreators` instead of `getCreators` is also valid, but deprecated. + - If the user wants to build a TensorRT engine first and run later, the user has the option to serialize the shared plugin library as part of the engine itself by specifying `--setPluginsToSerialize`. By doing so, the user does not have to specify `--dynamicPlugins` to `trtexec` when running the built engine. + - For more information on these flags, run `./trtexec --help`. +
### Example 2: Running a network on DLA diff --git a/samples/trtexec/trtexec.cpp b/samples/trtexec/trtexec.cpp index a701c149d..96b1b8e1f 100644 --- a/samples/trtexec/trtexec.cpp +++ b/samples/trtexec/trtexec.cpp @@ -272,8 +272,9 @@ int main(int argc, char** argv) { sample::setReportableSeverity(ILogger::Severity::kVERBOSE); } - +#if !TRT_WINML setCudaDevice(options.system.device, sample::gLogInfo); +#endif sample::gLogInfo << std::endl; sample::gLogInfo << "TensorRT version: " << NV_TENSORRT_MAJOR << "." << NV_TENSORRT_MINOR << "." << NV_TENSORRT_PATCH << std::endl; @@ -433,6 +434,7 @@ int main(int argc, char** argv) if (profilerEnabled && !options.inference.rerun) { iEnv->profiler.reset(new Profiler); +#if !TRT_WINML if (options.inference.graph && (getCudaDriverVersion() < 11010 || getCudaRuntimeVersion() < 11000)) { options.inference.graph = false; @@ -441,6 +443,7 @@ int main(int argc, char** argv) "and disabled CUDA graph." << std::endl; } +#endif } if (!setUpInference(*iEnv, options.inference, options.system)) @@ -486,6 +489,7 @@ int main(int argc, char** argv) iEnv->profiler.reset(profiler); iEnv->contexts.front()->setProfiler(profiler); iEnv->contexts.front()->setEnqueueEmitsProfile(false); +#if !TRT_WINML if (options.inference.graph && (getCudaDriverVersion() < 11010 || getCudaRuntimeVersion() < 11000)) { options.inference.graph = false; @@ -494,6 +498,7 @@ int main(int argc, char** argv) "and disabled CUDA graph." << std::endl; } +#endif if (!runInference(options.inference, *iEnv, options.system.device, trace)) { sample::gLogError << "Error occurred during inference" << std::endl; diff --git a/tools/Polygraphy/CHANGELOG.md b/tools/Polygraphy/CHANGELOG.md index 5196870f0..eda347fc2 100644 --- a/tools/Polygraphy/CHANGELOG.md +++ b/tools/Polygraphy/CHANGELOG.md @@ -3,6 +3,11 @@ Dates are in YYYY-MM-DD format. +## v0.49.14 (2024-09-10) +### Added +- Added `DataType.FLOAT4` for 4-bit floats (E2M1). + + ## v0.49.13 (2024-07-15) ### Added - Added option to emit logs using python `logging` module. diff --git a/tools/Polygraphy/polygraphy/__init__.py b/tools/Polygraphy/polygraphy/__init__.py index 5d3949527..9fddaeacb 100644 --- a/tools/Polygraphy/polygraphy/__init__.py +++ b/tools/Polygraphy/polygraphy/__init__.py @@ -1,3 +1,3 @@ import polygraphy.config -__version__ = "0.49.13" +__version__ = "0.49.14" diff --git a/tools/Polygraphy/polygraphy/datatype/datatype.py b/tools/Polygraphy/polygraphy/datatype/datatype.py index 22a1bb9ad..918e177fa 100644 --- a/tools/Polygraphy/polygraphy/datatype/datatype.py +++ b/tools/Polygraphy/polygraphy/datatype/datatype.py @@ -86,6 +86,7 @@ class DataType: "FLOAT64": DataTypeEntry("float64", 8, _DataTypeKind.FLOATING_POINT), "FLOAT32": DataTypeEntry("float32", 4, _DataTypeKind.FLOATING_POINT), "FLOAT16": DataTypeEntry("float16", 2, _DataTypeKind.FLOATING_POINT), + "FLOAT4": DataTypeEntry("float4", 0.5, _DataTypeKind.FLOATING_POINT), "INT16": DataTypeEntry("int16", 2, _DataTypeKind.INTEGRAL), "INT32": DataTypeEntry("int32", 4, _DataTypeKind.INTEGRAL), "INT64": DataTypeEntry("int64", 8, _DataTypeKind.INTEGRAL), diff --git a/tools/Polygraphy/polygraphy/datatype/tensorrt.py b/tools/Polygraphy/polygraphy/datatype/tensorrt.py index f59f8086a..47dae6a74 100644 --- a/tools/Polygraphy/polygraphy/datatype/tensorrt.py +++ b/tools/Polygraphy/polygraphy/datatype/tensorrt.py @@ -37,6 +37,7 @@ def _get_mapping(): util.try_getattr(trt, "bfloat16"): DataType.BFLOAT16, util.try_getattr(trt, "fp8"): DataType.FLOAT8E4M3FN, util.try_getattr(trt, "int4"): DataType.INT4, + util.try_getattr(trt, "fp4"): DataType.FLOAT4, } if None in DATATYPE_FROM_TENSORRT: del DATATYPE_FROM_TENSORRT[None] diff --git a/tools/Polygraphy/tests/common/test_datatype.py b/tools/Polygraphy/tests/common/test_datatype.py index c1515e18e..36df7d779 100644 --- a/tools/Polygraphy/tests/common/test_datatype.py +++ b/tools/Polygraphy/tests/common/test_datatype.py @@ -45,6 +45,7 @@ def test_numpy(self, dtype): DataType.FLOAT8E5M2, DataType.FLOAT8E5M2FNUZ, DataType.INT4, + DataType.FLOAT4, ]: pytest.xfail("Type not supported by NumPy") @@ -61,6 +62,7 @@ def test_numpy(self, dtype): def test_onnxrt(self, dtype): if dtype in [ DataType.INT4, + DataType.FLOAT4, ]: pytest.skip("Type not supported by ONNX-RT") @@ -84,6 +86,7 @@ def test_onnxrt(self, dtype): def test_onnx(self, dtype): if dtype in [ DataType.INT4, + DataType.FLOAT4, ]: pytest.skip("Type not supported by ONNX") @@ -137,6 +140,7 @@ def test_tensorrt(self, dtype): "float": "float32", "half": "float16", "fp8": "float8e4m3fn", + "fp4": "float4", "bf16": "bfloat16", }, ) @@ -162,6 +166,7 @@ def test_torch(self, dtype): DataType.UINT64, DataType.STRING, DataType.INT4, + DataType.FLOAT4, ]: pytest.xfail("Type not supported by Torch") diff --git a/tools/onnx-graphsurgeon/CHANGELOG.md b/tools/onnx-graphsurgeon/CHANGELOG.md index 7eaed393e..ee3000b5c 100644 --- a/tools/onnx-graphsurgeon/CHANGELOG.md +++ b/tools/onnx-graphsurgeon/CHANGELOG.md @@ -2,7 +2,7 @@ Dates are in YYYY-MM-DD format. -## v0.5.3 (TBD) +## v0.5.3 (2024-10-14) ### Added - Added `export_dtype` field to `gs.Constant` to allow numpy-unsupported dtypes such as BFloat16. diff --git a/tools/onnx-graphsurgeon/onnx_graphsurgeon/__init__.py b/tools/onnx-graphsurgeon/onnx_graphsurgeon/__init__.py index 6756daa1c..32367d7ec 100644 --- a/tools/onnx-graphsurgeon/onnx_graphsurgeon/__init__.py +++ b/tools/onnx-graphsurgeon/onnx_graphsurgeon/__init__.py @@ -7,4 +7,4 @@ from onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable from onnx_graphsurgeon.util.exception import OnnxGraphSurgeonException -__version__ = "0.5.2" +__version__ = "0.5.3" diff --git a/tools/onnx-graphsurgeon/tests/test_examples.py b/tools/onnx-graphsurgeon/tests/test_examples.py index fd86fbb0e..345b9f388 100644 --- a/tools/onnx-graphsurgeon/tests/test_examples.py +++ b/tools/onnx-graphsurgeon/tests/test_examples.py @@ -51,9 +51,11 @@ def __init__(self, name, infer=True): ("09_shape_operations_with_the_layer_api", [Artifact("model.onnx")]), ("10_dynamic_batch_size", [Artifact("model.onnx"), Artifact("dynamic.onnx")]), ("11_creating_a_local_function", [Artifact("model.onnx")]), - # Skipping inference test as bf16 is not supported in ORT yet. - ("12_using_bf16", [Artifact("test_conv_bf16.onnx", infer=False)]), + ( + "12_using_numpy_unsupported_dtypes", + [Artifact("test_conv_bf16.onnx", infer=False)], + ), ]