From fef83f3314721c7d7854ee7cfc3d8acac150a75b Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 12 May 2025 15:51:13 -0700 Subject: [PATCH 1/9] gate tensorrt plugin --- py/torch_tensorrt/_features.py | 18 +- .../conversion/plugins/_generate_plugin.py | 362 +++++++++--------- .../plugins/_generate_plugin_converter.py | 203 +++++----- 3 files changed, 309 insertions(+), 274 deletions(-) diff --git a/py/torch_tensorrt/_features.py b/py/torch_tensorrt/_features.py index bee0c3dbf0..abbd5c976f 100644 --- a/py/torch_tensorrt/_features.py +++ b/py/torch_tensorrt/_features.py @@ -15,6 +15,7 @@ "dynamo_frontend", "fx_frontend", "refit", + "tensorrt_plugin", ], ) @@ -39,14 +40,27 @@ _FX_FE_AVAIL = True _REFIT_AVAIL = True +try: + import tensorrt.plugin as trtp + + assert trtp + _TENSORRT_PLUGIN_AVAIL = True +except ImportError: + _TENSORRT_PLUGIN_AVAIL = False + ENABLED_FEATURES = FeatureSet( - _TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL + _TS_FE_AVAIL, + _TORCHTRT_RT_AVAIL, + _DYNAMO_FE_AVAIL, + _FX_FE_AVAIL, + _REFIT_AVAIL, + _TENSORRT_PLUGIN_AVAIL, ) def _enabled_features_str() -> str: enabled = lambda x: "ENABLED" if x else "DISABLED" - out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n" # type: ignore[no-untyped-call] + out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - TensorRT Plugin: {enabled(_TENSORRT_PLUGIN_AVAIL)}\n" # type: ignore[no-untyped-call] return out_str diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py index 8f5f173a7b..e63e43a7f5 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py @@ -3,7 +3,6 @@ from types import FunctionType from typing import Any, Callable, Tuple -import tensorrt.plugin as trtp import torch from sympy import lambdify from torch._dynamo.source import LocalSource @@ -12,211 +11,222 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) +try: + import tensorrt.plugin as trtp -def mksym( - shape_env: ShapeEnv, value: int, source: LocalSource, dynamic_dim: DimDynamic -) -> torch.SymInt: - return shape_env.create_symintnode( - shape_env.create_symbol( - value, - source=source, - dynamic_dim=dynamic_dim, - ), - hint=value, - source=source, + assert trtp +except ImportError as e: + _LOGGER.warning( + "Unable to import TensorRT plugin. 
Please install TensorRT plugin library (https://github.com/NVIDIA/TensorRT-plugin-library?tab=readme-ov-file#installation) to add support for compiling quantized models" ) +else: + + def mksym( + shape_env: ShapeEnv, value: int, source: LocalSource, dynamic_dim: DimDynamic + ) -> torch.SymInt: + return shape_env.create_symintnode( + shape_env.create_symbol( + value, + source=source, + dynamic_dim=dynamic_dim, + ), + hint=value, + source=source, + ) + def _generate_plugin(plugin_name: str) -> None: + namespace, name = plugin_name.split("::") + + # retrieve the corresponding torch operation using the passed in string + torch_op = getattr(getattr(torch.ops, namespace), name) + + # helper function that generates the required signature based on the torch operation + def generate_signature( + torch_op: Callable[[Any], Any], + ) -> Tuple[str, str, str, dict[str, Any], dict[str, Any]]: + schema = torch_op._schemas[""] + + arg_list = [] + + register_func_annotation = {} + impl_func_annotation = {} + + for arg in schema.arguments: + arg_list.append(arg.name) + + # TODO: Torch types need to be converted to python primitive types here + # Some other types are not handled: + # - torch._C.ListType.ofT() + # - torch._C.TupleType.get() + # - torch._C.DictType.get(, ) + # - torch._C.OptionalType.ofT() + # - torch._C.DeviceObjType.get() + # - torch._C.FunctionType.get() + # - torch._C.ClassType + + if arg.type.isSubtypeOf(torch._C.TensorType.get()): + register_func_annotation[arg.name] = trtp.TensorDesc + impl_func_annotation[arg.name] = trtp.Tensor + elif arg.type.isSubtypeOf(torch._C.FloatType.get()): + register_func_annotation[arg.name] = float + impl_func_annotation[arg.name] = float + elif arg.type.isSubtypeOf(torch._C.IntType.get()): + register_func_annotation[arg.name] = int + impl_func_annotation[arg.name] = int + elif arg.type.isSubtypeOf(torch._C.Booltype.get()): + register_func_annotation[arg.name] = bool + impl_func_annotation[arg.name] = bool + elif arg.type.isSubtypeOf(torch._C.Stringtype.get()): + register_func_annotation[arg.name] = str + impl_func_annotation[arg.name] = str + else: + raise ValueError("arg type is not handled") + + input_signature = ", ".join(arg_list) + + plugin_signature = f"def add_plugin_desc({input_signature}):" + + plugin_impl_arg_list = arg_list + plugin_impl_arg_list.append("outputs") + plugin_impl_arg_list.append("stream") + plugin_impl_input = ", ".join(plugin_impl_arg_list) + plugin_impl_signature = f"def add_plugin_impl({plugin_impl_input}):" + + register_func_annotation["return"] = Tuple[trtp.TensorDesc] + + impl_func_annotation["outputs"] = Tuple[trtp.Tensor] + impl_func_annotation["stream"] = int + + return ( + input_signature, + plugin_signature, + plugin_impl_signature, + register_func_annotation, + impl_func_annotation, + ) -def _generate_plugin(plugin_name: str) -> None: - namespace, name = plugin_name.split("::") - - # retrieve the corresponding torch operation using the passed in string - torch_op = getattr(getattr(torch.ops, namespace), name) - - # helper function that generates the required signature based on the torch operation - def generate_signature( - torch_op: Callable[[Any], Any], - ) -> Tuple[str, str, str, dict[str, Any], dict[str, Any]]: - schema = torch_op._schemas[""] - - arg_list = [] - - register_func_annotation = {} - impl_func_annotation = {} - - for arg in schema.arguments: - arg_list.append(arg.name) - - # TODO: Torch types need to be converted to python primitive types here - # Some other types are not handled: - # - 
torch._C.ListType.ofT() - # - torch._C.TupleType.get() - # - torch._C.DictType.get(, ) - # - torch._C.OptionalType.ofT() - # - torch._C.DeviceObjType.get() - # - torch._C.FunctionType.get() - # - torch._C.ClassType - - if arg.type.isSubtypeOf(torch._C.TensorType.get()): - register_func_annotation[arg.name] = trtp.TensorDesc - impl_func_annotation[arg.name] = trtp.Tensor - elif arg.type.isSubtypeOf(torch._C.FloatType.get()): - register_func_annotation[arg.name] = float - impl_func_annotation[arg.name] = float - elif arg.type.isSubtypeOf(torch._C.IntType.get()): - register_func_annotation[arg.name] = int - impl_func_annotation[arg.name] = int - elif arg.type.isSubtypeOf(torch._C.Booltype.get()): - register_func_annotation[arg.name] = bool - impl_func_annotation[arg.name] = bool - elif arg.type.isSubtypeOf(torch._C.Stringtype.get()): - register_func_annotation[arg.name] = str - impl_func_annotation[arg.name] = str - else: - raise ValueError("arg type is not handled") - - input_signature = ", ".join(arg_list) - - plugin_signature = f"def add_plugin_desc({input_signature}):" - - plugin_impl_arg_list = arg_list - plugin_impl_arg_list.append("outputs") - plugin_impl_arg_list.append("stream") - plugin_impl_input = ", ".join(plugin_impl_arg_list) - plugin_impl_signature = f"def add_plugin_impl({plugin_impl_input}):" - - register_func_annotation["return"] = Tuple[trtp.TensorDesc] - - impl_func_annotation["outputs"] = Tuple[trtp.Tensor] - impl_func_annotation["stream"] = int - - return ( + # Use the helper function to get the required signatures + ( input_signature, plugin_signature, plugin_impl_signature, register_func_annotation, impl_func_annotation, - ) - - # Use the helper function to get the required signatures - ( - input_signature, - plugin_signature, - plugin_impl_signature, - register_func_annotation, - impl_func_annotation, - ) = generate_signature(torch_op) - - def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]: - shape_env = ShapeEnv() - syms_args = [] - tensor_args = [elem for elem in args if isinstance(elem, trtp.TensorDesc)] - - for tensor_arg in tensor_args: - - sample = {f"{i}": 5 for i in range(tensor_arg.ndim)} - syms_arg = [ - mksym(shape_env, v, LocalSource(k), DimDynamic.DYNAMIC) - for k, v in sample.items() - ] - syms_args.append(syms_arg) - - with FakeTensorMode(shape_env=shape_env) as fake_mode: - fake_args = [] - for syms_arg in syms_args: - fake_arg = torch.randn(syms_arg) - fake_args.append(fake_arg) - - output = torch_op(*fake_args, **kwargs) - - # We assume that number of dimensions are the same in torch op - shape_calc_fns = [None] * output.ndim - - for i in range(output.ndim): - input_node_expr = list( - itertools.chain.from_iterable( - [sym.node.expr for sym in syms_arg] for syms_arg in syms_args + ) = generate_signature(torch_op) + + def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]: + shape_env = ShapeEnv() + syms_args = [] + tensor_args = [elem for elem in args if isinstance(elem, trtp.TensorDesc)] + + for tensor_arg in tensor_args: + + sample = {f"{i}": 5 for i in range(tensor_arg.ndim)} + syms_arg = [ + mksym(shape_env, v, LocalSource(k), DimDynamic.DYNAMIC) + for k, v in sample.items() + ] + syms_args.append(syms_arg) + + with FakeTensorMode(shape_env=shape_env) as fake_mode: + fake_args = [] + for syms_arg in syms_args: + fake_arg = torch.randn(syms_arg) + fake_args.append(fake_arg) + + output = torch_op(*fake_args, **kwargs) + + # We assume that number of dimensions are the same in torch op + 
shape_calc_fns = [None] * output.ndim + + for i in range(output.ndim): + input_node_expr = list( + itertools.chain.from_iterable( + [sym.node.expr for sym in syms_arg] for syms_arg in syms_args + ) ) - ) - - shape_calc_fns[i] = lambdify( - tuple(input_node_expr), output.shape[i].node.expr, "math" - ) - out_desc = tensor_args[0].like() - for i in range(out_desc.ndim): - input_shape_expr = list( - itertools.chain.from_iterable(arg.shape_expr for arg in tensor_args) - ) + shape_calc_fns[i] = lambdify( + tuple(input_node_expr), output.shape[i].node.expr, "math" + ) - if output.shape[i].node.expr is None: - raise ValueError(f"output.shape[{i}].node.expr cannot be None") - out_desc.shape_expr[i] = shape_calc_fns[i](*input_shape_expr) # type: ignore[misc] + out_desc = tensor_args[0].like() + for i in range(out_desc.ndim): + input_shape_expr = list( + itertools.chain.from_iterable(arg.shape_expr for arg in tensor_args) + ) - return (out_desc,) + if output.shape[i].node.expr is None: + raise ValueError(f"output.shape[{i}].node.expr cannot be None") + out_desc.shape_expr[i] = shape_calc_fns[i](*input_shape_expr) # type: ignore[misc] - codegen_plugin = f""" -{plugin_signature} - return _generic_plugin_desc({input_signature}) - """ + return (out_desc,) - _LOGGER.warning(f"Plugin registration function: \n{codegen_plugin}") + codegen_plugin = f""" + {plugin_signature} + return _generic_plugin_desc({input_signature}) + """ - plugin_code = compile(codegen_plugin, "", "exec") + _LOGGER.warning(f"Plugin registration function: \n{codegen_plugin}") - globals()["_generic_plugin_desc"] = _generic_plugin_desc + plugin_code = compile(codegen_plugin, "", "exec") - plugin = FunctionType( - plugin_code.co_consts[0], - globals(), - "plugin", - ) + globals()["_generic_plugin_desc"] = _generic_plugin_desc - # Function annotation is required for dynamic function to work in TensorRT.Plugin - plugin.__annotations__ = register_func_annotation + plugin = FunctionType( + plugin_code.co_consts[0], + globals(), + "plugin", + ) - trtp.register(plugin_name)(plugin) + # Function annotation is required for dynamic function to work in TensorRT.Plugin + plugin.__annotations__ = register_func_annotation - def _generic_plugin_impl( - outputs: Tuple[trtp.Tensor], stream: int, *args: Any, **kwargs: Any - ) -> None: - tensor_args = [elem for elem in args if isinstance(elem, trtp.Tensor)] - non_tensor_args = [elem for elem in args if not isinstance(elem, trtp.Tensor)] - in_tensors = [torch.as_tensor(i, device="cuda") for i in tensor_args] + trtp.register(plugin_name)(plugin) - dest_tensors = [torch.as_tensor(o, device="cuda") for o in outputs] + def _generic_plugin_impl( + outputs: Tuple[trtp.Tensor], stream: int, *args: Any, **kwargs: Any + ) -> None: + tensor_args = [elem for elem in args if isinstance(elem, trtp.Tensor)] + non_tensor_args = [ + elem for elem in args if not isinstance(elem, trtp.Tensor) + ] + in_tensors = [torch.as_tensor(i, device="cuda") for i in tensor_args] - stream = torch.cuda.ExternalStream(stream) - with torch.cuda.stream(stream): - out_tensors = torch_op(*in_tensors, *non_tensor_args, **kwargs) - if isinstance(out_tensors, torch.Tensor): - out_tensors = (out_tensors,) - [d.copy_(o) for (d, o) in zip(dest_tensors, out_tensors)] + dest_tensors = [torch.as_tensor(o, device="cuda") for o in outputs] - plugin_impl_func = f""" -{plugin_impl_signature} - _generic_plugin_impl(outputs, stream, {input_signature}) - """ + stream = torch.cuda.ExternalStream(stream) + with torch.cuda.stream(stream): + out_tensors = 
torch_op(*in_tensors, *non_tensor_args, **kwargs) + if isinstance(out_tensors, torch.Tensor): + out_tensors = (out_tensors,) + [d.copy_(o) for (d, o) in zip(dest_tensors, out_tensors)] - _LOGGER.warning(f"Plugin implementation function: \n{plugin_impl_func}") + plugin_impl_func = f""" + {plugin_impl_signature} + _generic_plugin_impl(outputs, stream, {input_signature}) + """ - plugin_impl_code = compile(plugin_impl_func, "", "exec") + _LOGGER.warning(f"Plugin implementation function: \n{plugin_impl_func}") - globals()["_generic_plugin_impl"] = _generic_plugin_impl + plugin_impl_code = compile(plugin_impl_func, "", "exec") - plugin_impl = FunctionType(plugin_impl_code.co_consts[0], globals(), "plugin_impl") + globals()["_generic_plugin_impl"] = _generic_plugin_impl - plugin_impl.__annotations__ = impl_func_annotation + plugin_impl = FunctionType( + plugin_impl_code.co_consts[0], globals(), "plugin_impl" + ) - trtp.impl(plugin_name)(plugin_impl) + plugin_impl.__annotations__ = impl_func_annotation + trtp.impl(plugin_name)(plugin_impl) -def generate_plugin(plugin_name: str) -> None: - """ - Generate the Plugin using external kernels and TensorRT Quick Deployable Plugin APIs. + def generate_plugin(plugin_name: str) -> None: + """ + Generate the Plugin using external kernels and TensorRT Quick Deployable Plugin APIs. - Args: - plugin_name: the plugin name that is used to generate the plugin automatically. - There should be existing kernels and pytorch custom operation for this plugin name. - """ - _generate_plugin(plugin_name) + Args: + plugin_name: the plugin name that is used to generate the plugin automatically. + There should be existing kernels and pytorch custom operation for this plugin name. + """ + _generate_plugin(plugin_name) diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py index 99ea3bc356..f6a44ba121 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py @@ -4,9 +4,6 @@ import numpy as np import tensorrt as trt - -# Seems like a bug in TensorRT -import tensorrt.plugin as trtp import torch from tensorrt.plugin._lib import QDP_REGISTRY from torch.fx.node import Argument, Node, Target @@ -22,98 +19,112 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) +try: + import tensorrt.plugin as trtp -def _generate_plugin_converter( - namespace: str, - op_name: str, - overload: Optional[str] = None, - capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None, - priority: ConverterPriority = ConverterPriority.STANDARD, - supports_dynamic_shapes: bool = False, - requires_output_allocator: bool = False, -) -> DynamoConverterImplSignature: - torch_target = getattr(getattr(torch.ops, namespace), op_name) - overload_str = overload if overload else "" - overload_name = overload_str if overload else "default" - torch_overload = getattr(torch_target, overload_name) - assert ( - f"{namespace}::{op_name}" in QDP_REGISTRY - ), f"Could not find a tensorrt plugin registered for op {namespace}::{op_name}, unable to generate converter" - torch_schema = torch_target._schemas[overload_str] - - def custom_kernel_converter( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, - ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: - - plugin = getattr(getattr(trtp.op, namespace), op_name) - - 
tensor_inputs = plugin.input_tensor_names - tensor_args = args[0 : len(tensor_inputs)] - - unique_id = uuid.uuid4() - itensor_args = [ - get_trt_tensor(ctx, t, f"{t_name}_{unique_id}") - for (t, t_name) in zip(tensor_args, tensor_inputs) - ] - - # Assuming TensorRT preserves kwargs order like PyTorch does - non_tensor_inputs = plugin.input_attrs - - kwargs = {} - - for arg in torch_schema.arguments: - if arg.default_value is not None: - kwargs[arg.name] = arg.default_value - - non_tensor_args = args[len(tensor_inputs) :] - non_tensor_kwargs = dict(zip(list(non_tensor_inputs.keys()), non_tensor_args)) - - for k, v in kwargs.items(): - if k in non_tensor_kwargs: - kwargs[k] = non_tensor_kwargs[k] - - for k, v in kwargs.items(): - if isinstance(v, torch.fx.immutable_collections.immutable_list): - kwargs[k] = np.array(v) - - layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs)) - assert layer, f"{namespace}::{name} plugin layer was not able to be created" - _LOGGER.debug( - f"Adding generated plugin for {namespace}::{name} to tensorrt network" - ) - layer.name = f"[{target}]-[{name}]" - return layer.get_output(0) - - custom_kernel_converter = dynamo_tensorrt_converter( - torch_overload, - capability_validator=capability_validator, - priority=priority, - supports_dynamic_shapes=supports_dynamic_shapes, - requires_output_allocator=requires_output_allocator, - )(custom_kernel_converter) - assert ( - torch_overload in DYNAMO_CONVERTERS - ), f"Generated dynamo converter for {namespace}::{op_name} did not get properly registered in the converter registry" - return custom_kernel_converter - - -def generate_plugin_converter( - plugin_id: str, - capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None, - priority: ConverterPriority = ConverterPriority.STANDARD, - supports_dynamic_shapes: bool = False, - requires_output_allocator: bool = False, -) -> DynamoConverterImplSignature: - plugin_ns, plugin_name = plugin_id.split("::") - return _generate_plugin_converter( - plugin_ns, - plugin_name, - capability_validator=capability_validator, - priority=priority, - supports_dynamic_shapes=supports_dynamic_shapes, - requires_output_allocator=requires_output_allocator, + assert trtp +except ImportError as e: + _LOGGER.warning( + "Unable to import TensorRT plugin. 
Please install TensorRT plugin library (https://github.com/NVIDIA/TensorRT-plugin-library?tab=readme-ov-file#installation) to add support for compiling quantized models" ) +else: + + def _generate_plugin_converter( + namespace: str, + op_name: str, + overload: Optional[str] = None, + capability_validator: Optional[ + Callable[[Node, CompilationSettings], bool] + ] = None, + priority: ConverterPriority = ConverterPriority.STANDARD, + supports_dynamic_shapes: bool = False, + requires_output_allocator: bool = False, + ) -> DynamoConverterImplSignature: + torch_target = getattr(getattr(torch.ops, namespace), op_name) + overload_str = overload if overload else "" + overload_name = overload_str if overload else "default" + torch_overload = getattr(torch_target, overload_name) + assert ( + f"{namespace}::{op_name}" in QDP_REGISTRY + ), f"Could not find a tensorrt plugin registered for op {namespace}::{op_name}, unable to generate converter" + torch_schema = torch_target._schemas[overload_str] + + def custom_kernel_converter( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + + plugin = getattr(getattr(trtp.op, namespace), op_name) + + tensor_inputs = plugin.input_tensor_names + tensor_args = args[0 : len(tensor_inputs)] + + unique_id = uuid.uuid4() + itensor_args = [ + get_trt_tensor(ctx, t, f"{t_name}_{unique_id}") + for (t, t_name) in zip(tensor_args, tensor_inputs) + ] + + # Assuming TensorRT preserves kwargs order like PyTorch does + non_tensor_inputs = plugin.input_attrs + + kwargs = {} + + for arg in torch_schema.arguments: + if arg.default_value is not None: + kwargs[arg.name] = arg.default_value + + non_tensor_args = args[len(tensor_inputs) :] + non_tensor_kwargs = dict( + zip(list(non_tensor_inputs.keys()), non_tensor_args) + ) + + for k, v in kwargs.items(): + if k in non_tensor_kwargs: + kwargs[k] = non_tensor_kwargs[k] + + for k, v in kwargs.items(): + if isinstance(v, torch.fx.immutable_collections.immutable_list): + kwargs[k] = np.array(v) + + layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs)) + assert layer, f"{namespace}::{name} plugin layer was not able to be created" + _LOGGER.debug( + f"Adding generated plugin for {namespace}::{name} to tensorrt network" + ) + layer.name = f"[{target}]-[{name}]" + return layer.get_output(0) + + custom_kernel_converter = dynamo_tensorrt_converter( + torch_overload, + capability_validator=capability_validator, + priority=priority, + supports_dynamic_shapes=supports_dynamic_shapes, + requires_output_allocator=requires_output_allocator, + )(custom_kernel_converter) + assert ( + torch_overload in DYNAMO_CONVERTERS + ), f"Generated dynamo converter for {namespace}::{op_name} did not get properly registered in the converter registry" + return custom_kernel_converter + + def generate_plugin_converter( + plugin_id: str, + capability_validator: Optional[ + Callable[[Node, CompilationSettings], bool] + ] = None, + priority: ConverterPriority = ConverterPriority.STANDARD, + supports_dynamic_shapes: bool = False, + requires_output_allocator: bool = False, + ) -> DynamoConverterImplSignature: + plugin_ns, plugin_name = plugin_id.split("::") + return _generate_plugin_converter( + plugin_ns, + plugin_name, + capability_validator=capability_validator, + priority=priority, + supports_dynamic_shapes=supports_dynamic_shapes, + requires_output_allocator=requires_output_allocator, + ) From 36a7010c701c9dc867e32c3125b2d6bcf0e22349 Mon Sep 17 
00:00:00 2001 From: lanluo-nvidia Date: Mon, 19 May 2025 20:33:36 -0700 Subject: [PATCH 2/9] address comments from Naren --- py/torch_tensorrt/_features.py | 32 +- .../conversion/plugins/_generate_plugin.py | 370 +++++++++--------- .../plugins/_generate_plugin_converter.py | 209 +++++----- 3 files changed, 308 insertions(+), 303 deletions(-) diff --git a/py/torch_tensorrt/_features.py b/py/torch_tensorrt/_features.py index abbd5c976f..09d1038cc6 100644 --- a/py/torch_tensorrt/_features.py +++ b/py/torch_tensorrt/_features.py @@ -1,3 +1,4 @@ +import importlib import os import sys from collections import namedtuple @@ -15,7 +16,7 @@ "dynamo_frontend", "fx_frontend", "refit", - "tensorrt_plugin", + "qdp_plugin", ], ) @@ -40,13 +41,10 @@ _FX_FE_AVAIL = True _REFIT_AVAIL = True -try: - import tensorrt.plugin as trtp - - assert trtp - _TENSORRT_PLUGIN_AVAIL = True -except ImportError: - _TENSORRT_PLUGIN_AVAIL = False +if importlib.util.find_spec("tensorrt.plugin"): + _QDP_PLUGIN_AVAIL = True +else: + _QDP_PLUGIN_AVAIL = False ENABLED_FEATURES = FeatureSet( _TS_FE_AVAIL, @@ -54,13 +52,13 @@ _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL, - _TENSORRT_PLUGIN_AVAIL, + _QDP_PLUGIN_AVAIL, ) def _enabled_features_str() -> str: enabled = lambda x: "ENABLED" if x else "DISABLED" - out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - TensorRT Plugin: {enabled(_TENSORRT_PLUGIN_AVAIL)}\n" # type: ignore[no-untyped-call] + out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)}\n" # type: ignore[no-untyped-call] return out_str @@ -78,6 +76,20 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: return wrapper +def needs_qdp_plugin(f: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: + if ENABLED_FEATURES.qdp_plugin: + return f(*args, **kwargs) + else: + + def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: + raise NotImplementedError("QDP Plugin is not available") + + return not_implemented(*args, **kwargs) + + return wrapper + + def needs_refit(f: Callable[..., Any]) -> Callable[..., Any]: def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: if ENABLED_FEATURES.refit: diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py index e63e43a7f5..571e6afdcf 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py @@ -8,225 +8,223 @@ from torch._dynamo.source import LocalSource from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv +from torch_tensorrt._features import needs_qdp_plugin _LOGGER: logging.Logger = logging.getLogger(__name__) -try: - import tensorrt.plugin as trtp - assert trtp -except ImportError as e: - _LOGGER.warning( - "Unable to import TensorRT plugin. 
Please install TensorRT plugin library (https://github.com/NVIDIA/TensorRT-plugin-library?tab=readme-ov-file#installation) to add support for compiling quantized models" - ) -else: - - def mksym( - shape_env: ShapeEnv, value: int, source: LocalSource, dynamic_dim: DimDynamic - ) -> torch.SymInt: - return shape_env.create_symintnode( - shape_env.create_symbol( - value, - source=source, - dynamic_dim=dynamic_dim, - ), - hint=value, +def mksym( + shape_env: ShapeEnv, value: int, source: LocalSource, dynamic_dim: DimDynamic +) -> torch.SymInt: + return shape_env.create_symintnode( + shape_env.create_symbol( + value, source=source, - ) + dynamic_dim=dynamic_dim, + ), + hint=value, + source=source, + ) - def _generate_plugin(plugin_name: str) -> None: - namespace, name = plugin_name.split("::") - - # retrieve the corresponding torch operation using the passed in string - torch_op = getattr(getattr(torch.ops, namespace), name) - - # helper function that generates the required signature based on the torch operation - def generate_signature( - torch_op: Callable[[Any], Any], - ) -> Tuple[str, str, str, dict[str, Any], dict[str, Any]]: - schema = torch_op._schemas[""] - - arg_list = [] - - register_func_annotation = {} - impl_func_annotation = {} - - for arg in schema.arguments: - arg_list.append(arg.name) - - # TODO: Torch types need to be converted to python primitive types here - # Some other types are not handled: - # - torch._C.ListType.ofT() - # - torch._C.TupleType.get() - # - torch._C.DictType.get(, ) - # - torch._C.OptionalType.ofT() - # - torch._C.DeviceObjType.get() - # - torch._C.FunctionType.get() - # - torch._C.ClassType - - if arg.type.isSubtypeOf(torch._C.TensorType.get()): - register_func_annotation[arg.name] = trtp.TensorDesc - impl_func_annotation[arg.name] = trtp.Tensor - elif arg.type.isSubtypeOf(torch._C.FloatType.get()): - register_func_annotation[arg.name] = float - impl_func_annotation[arg.name] = float - elif arg.type.isSubtypeOf(torch._C.IntType.get()): - register_func_annotation[arg.name] = int - impl_func_annotation[arg.name] = int - elif arg.type.isSubtypeOf(torch._C.Booltype.get()): - register_func_annotation[arg.name] = bool - impl_func_annotation[arg.name] = bool - elif arg.type.isSubtypeOf(torch._C.Stringtype.get()): - register_func_annotation[arg.name] = str - impl_func_annotation[arg.name] = str - else: - raise ValueError("arg type is not handled") - - input_signature = ", ".join(arg_list) - - plugin_signature = f"def add_plugin_desc({input_signature}):" - - plugin_impl_arg_list = arg_list - plugin_impl_arg_list.append("outputs") - plugin_impl_arg_list.append("stream") - plugin_impl_input = ", ".join(plugin_impl_arg_list) - plugin_impl_signature = f"def add_plugin_impl({plugin_impl_input}):" - - register_func_annotation["return"] = Tuple[trtp.TensorDesc] - - impl_func_annotation["outputs"] = Tuple[trtp.Tensor] - impl_func_annotation["stream"] = int - - return ( - input_signature, - plugin_signature, - plugin_impl_signature, - register_func_annotation, - impl_func_annotation, - ) - # Use the helper function to get the required signatures - ( +def _generate_plugin(plugin_name: str) -> None: + try: + import tensorrt.plugin as trtp + except ImportError as e: + raise RuntimeError( + "Unable to import TensorRT plugin. 
Please install TensorRT plugin library (https://github.com/NVIDIA/TensorRT-plugin-library?tab=readme-ov-file#installation) to add support for compiling quantized models" + ) + + namespace, name = plugin_name.split("::") + + # retrieve the corresponding torch operation using the passed in string + torch_op = getattr(getattr(torch.ops, namespace), name) + + # helper function that generates the required signature based on the torch operation + def generate_signature( + torch_op: Callable[[Any], Any], + ) -> Tuple[str, str, str, dict[str, Any], dict[str, Any]]: + schema = torch_op._schemas[""] + + arg_list = [] + + register_func_annotation = {} + impl_func_annotation = {} + + for arg in schema.arguments: + arg_list.append(arg.name) + + # TODO: Torch types need to be converted to python primitive types here + # Some other types are not handled: + # - torch._C.ListType.ofT() + # - torch._C.TupleType.get() + # - torch._C.DictType.get(, ) + # - torch._C.OptionalType.ofT() + # - torch._C.DeviceObjType.get() + # - torch._C.FunctionType.get() + # - torch._C.ClassType + + if arg.type.isSubtypeOf(torch._C.TensorType.get()): + register_func_annotation[arg.name] = trtp.TensorDesc + impl_func_annotation[arg.name] = trtp.Tensor + elif arg.type.isSubtypeOf(torch._C.FloatType.get()): + register_func_annotation[arg.name] = float + impl_func_annotation[arg.name] = float + elif arg.type.isSubtypeOf(torch._C.IntType.get()): + register_func_annotation[arg.name] = int + impl_func_annotation[arg.name] = int + elif arg.type.isSubtypeOf(torch._C.Booltype.get()): + register_func_annotation[arg.name] = bool + impl_func_annotation[arg.name] = bool + elif arg.type.isSubtypeOf(torch._C.Stringtype.get()): + register_func_annotation[arg.name] = str + impl_func_annotation[arg.name] = str + else: + raise ValueError("arg type is not handled") + + input_signature = ", ".join(arg_list) + + plugin_signature = f"def add_plugin_desc({input_signature}):" + + plugin_impl_arg_list = arg_list + plugin_impl_arg_list.append("outputs") + plugin_impl_arg_list.append("stream") + plugin_impl_input = ", ".join(plugin_impl_arg_list) + plugin_impl_signature = f"def add_plugin_impl({plugin_impl_input}):" + + register_func_annotation["return"] = Tuple[trtp.TensorDesc] + + impl_func_annotation["outputs"] = Tuple[trtp.Tensor] + impl_func_annotation["stream"] = int + + return ( input_signature, plugin_signature, plugin_impl_signature, register_func_annotation, impl_func_annotation, - ) = generate_signature(torch_op) - - def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]: - shape_env = ShapeEnv() - syms_args = [] - tensor_args = [elem for elem in args if isinstance(elem, trtp.TensorDesc)] - - for tensor_arg in tensor_args: - - sample = {f"{i}": 5 for i in range(tensor_arg.ndim)} - syms_arg = [ - mksym(shape_env, v, LocalSource(k), DimDynamic.DYNAMIC) - for k, v in sample.items() - ] - syms_args.append(syms_arg) - - with FakeTensorMode(shape_env=shape_env) as fake_mode: - fake_args = [] - for syms_arg in syms_args: - fake_arg = torch.randn(syms_arg) - fake_args.append(fake_arg) - - output = torch_op(*fake_args, **kwargs) - - # We assume that number of dimensions are the same in torch op - shape_calc_fns = [None] * output.ndim - - for i in range(output.ndim): - input_node_expr = list( - itertools.chain.from_iterable( - [sym.node.expr for sym in syms_arg] for syms_arg in syms_args - ) - ) + ) - shape_calc_fns[i] = lambdify( - tuple(input_node_expr), output.shape[i].node.expr, "math" - ) + # Use the helper function to get the 
required signatures + ( + input_signature, + plugin_signature, + plugin_impl_signature, + register_func_annotation, + impl_func_annotation, + ) = generate_signature(torch_op) + + def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]: + shape_env = ShapeEnv() + syms_args = [] + tensor_args = [elem for elem in args if isinstance(elem, trtp.TensorDesc)] + + for tensor_arg in tensor_args: + + sample = {f"{i}": 5 for i in range(tensor_arg.ndim)} + syms_arg = [ + mksym(shape_env, v, LocalSource(k), DimDynamic.DYNAMIC) + for k, v in sample.items() + ] + syms_args.append(syms_arg) + + with FakeTensorMode(shape_env=shape_env) as fake_mode: + fake_args = [] + for syms_arg in syms_args: + fake_arg = torch.randn(syms_arg) + fake_args.append(fake_arg) + + output = torch_op(*fake_args, **kwargs) - out_desc = tensor_args[0].like() - for i in range(out_desc.ndim): - input_shape_expr = list( - itertools.chain.from_iterable(arg.shape_expr for arg in tensor_args) + # We assume that number of dimensions are the same in torch op + shape_calc_fns = [None] * output.ndim + + for i in range(output.ndim): + input_node_expr = list( + itertools.chain.from_iterable( + [sym.node.expr for sym in syms_arg] for syms_arg in syms_args ) + ) - if output.shape[i].node.expr is None: - raise ValueError(f"output.shape[{i}].node.expr cannot be None") - out_desc.shape_expr[i] = shape_calc_fns[i](*input_shape_expr) # type: ignore[misc] + shape_calc_fns[i] = lambdify( + tuple(input_node_expr), output.shape[i].node.expr, "math" + ) - return (out_desc,) + out_desc = tensor_args[0].like() + for i in range(out_desc.ndim): + input_shape_expr = list( + itertools.chain.from_iterable(arg.shape_expr for arg in tensor_args) + ) - codegen_plugin = f""" - {plugin_signature} - return _generic_plugin_desc({input_signature}) - """ + if output.shape[i].node.expr is None: + raise ValueError(f"output.shape[{i}].node.expr cannot be None") + out_desc.shape_expr[i] = shape_calc_fns[i](*input_shape_expr) # type: ignore[misc] - _LOGGER.warning(f"Plugin registration function: \n{codegen_plugin}") + return (out_desc,) - plugin_code = compile(codegen_plugin, "", "exec") + codegen_plugin = f""" +{plugin_signature} + return _generic_plugin_desc({input_signature}) + """ - globals()["_generic_plugin_desc"] = _generic_plugin_desc + _LOGGER.warning(f"Plugin registration function: \n{codegen_plugin}") - plugin = FunctionType( - plugin_code.co_consts[0], - globals(), - "plugin", - ) + plugin_code = compile(codegen_plugin, "", "exec") - # Function annotation is required for dynamic function to work in TensorRT.Plugin - plugin.__annotations__ = register_func_annotation + globals()["_generic_plugin_desc"] = _generic_plugin_desc - trtp.register(plugin_name)(plugin) + plugin = FunctionType( + plugin_code.co_consts[0], + globals(), + "plugin", + ) - def _generic_plugin_impl( - outputs: Tuple[trtp.Tensor], stream: int, *args: Any, **kwargs: Any - ) -> None: - tensor_args = [elem for elem in args if isinstance(elem, trtp.Tensor)] - non_tensor_args = [ - elem for elem in args if not isinstance(elem, trtp.Tensor) - ] - in_tensors = [torch.as_tensor(i, device="cuda") for i in tensor_args] + # Function annotation is required for dynamic function to work in TensorRT.Plugin + plugin.__annotations__ = register_func_annotation - dest_tensors = [torch.as_tensor(o, device="cuda") for o in outputs] + trtp.register(plugin_name)(plugin) - stream = torch.cuda.ExternalStream(stream) - with torch.cuda.stream(stream): - out_tensors = torch_op(*in_tensors, 
*non_tensor_args, **kwargs) - if isinstance(out_tensors, torch.Tensor): - out_tensors = (out_tensors,) - [d.copy_(o) for (d, o) in zip(dest_tensors, out_tensors)] + def _generic_plugin_impl( + outputs: Tuple[trtp.Tensor], stream: int, *args: Any, **kwargs: Any + ) -> None: + tensor_args = [elem for elem in args if isinstance(elem, trtp.Tensor)] + non_tensor_args = [elem for elem in args if not isinstance(elem, trtp.Tensor)] + in_tensors = [torch.as_tensor(i, device="cuda") for i in tensor_args] - plugin_impl_func = f""" - {plugin_impl_signature} - _generic_plugin_impl(outputs, stream, {input_signature}) - """ + dest_tensors = [torch.as_tensor(o, device="cuda") for o in outputs] - _LOGGER.warning(f"Plugin implementation function: \n{plugin_impl_func}") + stream = torch.cuda.ExternalStream(stream) + with torch.cuda.stream(stream): + out_tensors = torch_op(*in_tensors, *non_tensor_args, **kwargs) + if isinstance(out_tensors, torch.Tensor): + out_tensors = (out_tensors,) + [d.copy_(o) for (d, o) in zip(dest_tensors, out_tensors)] - plugin_impl_code = compile(plugin_impl_func, "", "exec") + plugin_impl_func = f""" +{plugin_impl_signature} + _generic_plugin_impl(outputs, stream, {input_signature}) + """ - globals()["_generic_plugin_impl"] = _generic_plugin_impl + _LOGGER.warning(f"Plugin implementation function: \n{plugin_impl_func}") - plugin_impl = FunctionType( - plugin_impl_code.co_consts[0], globals(), "plugin_impl" - ) + plugin_impl_code = compile(plugin_impl_func, "", "exec") + + globals()["_generic_plugin_impl"] = _generic_plugin_impl + + plugin_impl = FunctionType(plugin_impl_code.co_consts[0], globals(), "plugin_impl") + + plugin_impl.__annotations__ = impl_func_annotation - plugin_impl.__annotations__ = impl_func_annotation + trtp.impl(plugin_name)(plugin_impl) - trtp.impl(plugin_name)(plugin_impl) - def generate_plugin(plugin_name: str) -> None: - """ - Generate the Plugin using external kernels and TensorRT Quick Deployable Plugin APIs. +@needs_qdp_plugin +def generate_plugin(plugin_name: str) -> None: + """ + Generate the Plugin using external kernels and TensorRT Quick Deployable Plugin APIs. - Args: - plugin_name: the plugin name that is used to generate the plugin automatically. - There should be existing kernels and pytorch custom operation for this plugin name. - """ - _generate_plugin(plugin_name) + Args: + plugin_name: the plugin name that is used to generate the plugin automatically. + There should be existing kernels and pytorch custom operation for this plugin name. + """ + _generate_plugin(plugin_name) diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py index f6a44ba121..112e6da957 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py @@ -7,6 +7,7 @@ import torch from tensorrt.plugin._lib import QDP_REGISTRY from torch.fx.node import Argument, Node, Target +from torch_tensorrt._features import needs_qdp_plugin from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( @@ -19,112 +20,106 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -try: - import tensorrt.plugin as trtp - assert trtp -except ImportError as e: - _LOGGER.warning( - "Unable to import TensorRT plugin. 
Please install TensorRT plugin library (https://github.com/NVIDIA/TensorRT-plugin-library?tab=readme-ov-file#installation) to add support for compiling quantized models" - ) -else: - - def _generate_plugin_converter( - namespace: str, - op_name: str, - overload: Optional[str] = None, - capability_validator: Optional[ - Callable[[Node, CompilationSettings], bool] - ] = None, - priority: ConverterPriority = ConverterPriority.STANDARD, - supports_dynamic_shapes: bool = False, - requires_output_allocator: bool = False, - ) -> DynamoConverterImplSignature: - torch_target = getattr(getattr(torch.ops, namespace), op_name) - overload_str = overload if overload else "" - overload_name = overload_str if overload else "default" - torch_overload = getattr(torch_target, overload_name) - assert ( - f"{namespace}::{op_name}" in QDP_REGISTRY - ), f"Could not find a tensorrt plugin registered for op {namespace}::{op_name}, unable to generate converter" - torch_schema = torch_target._schemas[overload_str] - - def custom_kernel_converter( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, - ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: - - plugin = getattr(getattr(trtp.op, namespace), op_name) - - tensor_inputs = plugin.input_tensor_names - tensor_args = args[0 : len(tensor_inputs)] - - unique_id = uuid.uuid4() - itensor_args = [ - get_trt_tensor(ctx, t, f"{t_name}_{unique_id}") - for (t, t_name) in zip(tensor_args, tensor_inputs) - ] - - # Assuming TensorRT preserves kwargs order like PyTorch does - non_tensor_inputs = plugin.input_attrs - - kwargs = {} - - for arg in torch_schema.arguments: - if arg.default_value is not None: - kwargs[arg.name] = arg.default_value - - non_tensor_args = args[len(tensor_inputs) :] - non_tensor_kwargs = dict( - zip(list(non_tensor_inputs.keys()), non_tensor_args) - ) - - for k, v in kwargs.items(): - if k in non_tensor_kwargs: - kwargs[k] = non_tensor_kwargs[k] - - for k, v in kwargs.items(): - if isinstance(v, torch.fx.immutable_collections.immutable_list): - kwargs[k] = np.array(v) - - layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs)) - assert layer, f"{namespace}::{name} plugin layer was not able to be created" - _LOGGER.debug( - f"Adding generated plugin for {namespace}::{name} to tensorrt network" - ) - layer.name = f"[{target}]-[{name}]" - return layer.get_output(0) - - custom_kernel_converter = dynamo_tensorrt_converter( - torch_overload, - capability_validator=capability_validator, - priority=priority, - supports_dynamic_shapes=supports_dynamic_shapes, - requires_output_allocator=requires_output_allocator, - )(custom_kernel_converter) - assert ( - torch_overload in DYNAMO_CONVERTERS - ), f"Generated dynamo converter for {namespace}::{op_name} did not get properly registered in the converter registry" - return custom_kernel_converter - - def generate_plugin_converter( - plugin_id: str, - capability_validator: Optional[ - Callable[[Node, CompilationSettings], bool] - ] = None, - priority: ConverterPriority = ConverterPriority.STANDARD, - supports_dynamic_shapes: bool = False, - requires_output_allocator: bool = False, - ) -> DynamoConverterImplSignature: - plugin_ns, plugin_name = plugin_id.split("::") - return _generate_plugin_converter( - plugin_ns, - plugin_name, - capability_validator=capability_validator, - priority=priority, - supports_dynamic_shapes=supports_dynamic_shapes, - requires_output_allocator=requires_output_allocator, +def _generate_plugin_converter( + namespace: str, + 
op_name: str, + overload: Optional[str] = None, + capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None, + priority: ConverterPriority = ConverterPriority.STANDARD, + supports_dynamic_shapes: bool = False, + requires_output_allocator: bool = False, +) -> DynamoConverterImplSignature: + try: + import tensorrt.plugin as trtp + except ImportError as e: + raise RuntimeError( + "Unable to import TensorRT plugin. Please install TensorRT plugin library (https://github.com/NVIDIA/TensorRT-plugin-library?tab=readme-ov-file#installation) to add support for compiling quantized models" + ) + + torch_target = getattr(getattr(torch.ops, namespace), op_name) + overload_str = overload if overload else "" + overload_name = overload_str if overload else "default" + torch_overload = getattr(torch_target, overload_name) + assert ( + f"{namespace}::{op_name}" in QDP_REGISTRY + ), f"Could not find a tensorrt plugin registered for op {namespace}::{op_name}, unable to generate converter" + torch_schema = torch_target._schemas[overload_str] + + def custom_kernel_converter( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + + plugin = getattr(getattr(trtp.op, namespace), op_name) + + tensor_inputs = plugin.input_tensor_names + tensor_args = args[0 : len(tensor_inputs)] + + unique_id = uuid.uuid4() + itensor_args = [ + get_trt_tensor(ctx, t, f"{t_name}_{unique_id}") + for (t, t_name) in zip(tensor_args, tensor_inputs) + ] + + # Assuming TensorRT preserves kwargs order like PyTorch does + non_tensor_inputs = plugin.input_attrs + + kwargs = {} + + for arg in torch_schema.arguments: + if arg.default_value is not None: + kwargs[arg.name] = arg.default_value + + non_tensor_args = args[len(tensor_inputs) :] + non_tensor_kwargs = dict(zip(list(non_tensor_inputs.keys()), non_tensor_args)) + + for k, v in kwargs.items(): + if k in non_tensor_kwargs: + kwargs[k] = non_tensor_kwargs[k] + + for k, v in kwargs.items(): + if isinstance(v, torch.fx.immutable_collections.immutable_list): + kwargs[k] = np.array(v) + + layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs)) + assert layer, f"{namespace}::{name} plugin layer was not able to be created" + _LOGGER.debug( + f"Adding generated plugin for {namespace}::{name} to tensorrt network" ) + layer.name = f"[{target}]-[{name}]" + return layer.get_output(0) + + custom_kernel_converter = dynamo_tensorrt_converter( + torch_overload, + capability_validator=capability_validator, + priority=priority, + supports_dynamic_shapes=supports_dynamic_shapes, + requires_output_allocator=requires_output_allocator, + )(custom_kernel_converter) + assert ( + torch_overload in DYNAMO_CONVERTERS + ), f"Generated dynamo converter for {namespace}::{op_name} did not get properly registered in the converter registry" + return custom_kernel_converter + + +@needs_qdp_plugin +def generate_plugin_converter( + plugin_id: str, + capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None, + priority: ConverterPriority = ConverterPriority.STANDARD, + supports_dynamic_shapes: bool = False, + requires_output_allocator: bool = False, +) -> DynamoConverterImplSignature: + plugin_ns, plugin_name = plugin_id.split("::") + return _generate_plugin_converter( + plugin_ns, + plugin_name, + capability_validator=capability_validator, + priority=priority, + supports_dynamic_shapes=supports_dynamic_shapes, + 
requires_output_allocator=requires_output_allocator,
+    )

From dcf1ee62050e68d55d023b687bf5395523aeec71 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Tue, 20 May 2025 08:32:08 -0700
Subject: [PATCH 3/9] test

---
 .../dynamo/conversion/plugins/_generate_plugin_converter.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py
index 112e6da957..242566d9c9 100644
--- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py
+++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py
@@ -5,7 +5,7 @@
 import numpy as np
 import tensorrt as trt
 import torch
-from tensorrt.plugin._lib import QDP_REGISTRY
+
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt._features import needs_qdp_plugin
 from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -32,11 +32,12 @@ def _generate_plugin_converter(
 ) -> DynamoConverterImplSignature:
     try:
         import tensorrt.plugin as trtp
+
     except ImportError as e:
         raise RuntimeError(
             "Unable to import TensorRT plugin. Please install TensorRT plugin library (https://github.com/NVIDIA/TensorRT-plugin-library?tab=readme-ov-file#installation) to add support for compiling quantized models"
         )
-
+    from tensorrt.plugin._lib import QDP_REGISTRY
     torch_target = getattr(getattr(torch.ops, namespace), op_name)
     overload_str = overload if overload else ""
     overload_name = overload_str if overload else "default"

From 8b82249ec2d86fba4418534f0f59a459d5c8f181 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Tue, 20 May 2025 09:11:00 -0700
Subject: [PATCH 4/9] resolve linting error

---
 .../dynamo/conversion/plugins/_generate_plugin_converter.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py
index 242566d9c9..926a5779c1 100644
--- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py
+++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py
@@ -5,7 +5,6 @@
 import numpy as np
 import tensorrt as trt
 import torch
-
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt._features import needs_qdp_plugin
 from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -38,6 +37,7 @@ def _generate_plugin_converter(
         "Unable to import TensorRT plugin. Please install TensorRT plugin library (https://github.com/NVIDIA/TensorRT-plugin-library?tab=readme-ov-file#installation) to add support for compiling quantized models"
     )
     from tensorrt.plugin._lib import QDP_REGISTRY
+
     torch_target = getattr(getattr(torch.ops, namespace), op_name)
     overload_str = overload if overload else ""
     overload_name = overload_str if overload else "default"
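Patches 3 and 4 above move the QDP_REGISTRY import from module scope into _generate_plugin_converter, so importing the converter module no longer hard-requires tensorrt.plugin. A minimal, self-contained sketch of that lazy-import pattern follows; the helper name load_qdp_registry is illustrative only and not part of this series:

import importlib


def load_qdp_registry():
    # Deferring the import keeps tensorrt.plugin a soft dependency:
    # importing this module always succeeds, and the failure surfaces
    # only on the first call that actually needs the plugin registry.
    try:
        importlib.import_module("tensorrt.plugin")
    except ImportError:
        raise RuntimeError("tensorrt.plugin is unavailable")
    # Reach into the private registry module only after the probe passed.
    return importlib.import_module("tensorrt.plugin._lib").QDP_REGISTRY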
From 6e5b8b4391f711b0e08683f0a5ac9fe50141e785 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Tue, 20 May 2025 10:17:49 -0700
Subject: [PATCH 5/9] fall back to old implementation for unsqueeze when tensorrt version < 10.7.0

---
 .../dynamo/conversion/aten_ops_converters.py  | 11 ++-
 .../dynamo/conversion/impl/unsqueeze.py       | 67 +++++++++++++++++++
 2 files changed, 77 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 1fed1f9a1f..26d1a1c157 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -661,7 +661,16 @@ def aten_ops_unsqueeze(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.unsqueeze.unsqueeze(ctx, target, SourceIR.ATEN, name, args[0], args[1])
+    from importlib.metadata import version
+
+    if version("tensorrt") >= "10.7.0":
+        return impl.unsqueeze.unsqueeze(
+            ctx, target, SourceIR.ATEN, name, args[0], args[1]
+        )
+    else:
+        return impl.unsqueeze.unsqueeze_old(
+            ctx, target, SourceIR.ATEN, name, args[0], args[1]
+        )
 
 
 @dynamo_tensorrt_converter(
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
index 3dacc2fbe4..1fc83a0c2c 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
@@ -4,6 +4,7 @@
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_positive_dim,
     get_trt_tensor,
     set_layer_name,
 )
@@ -24,6 +25,72 @@ def unsqueeze(
     return layer.get_output(0)
 
 
+# old implementation for Jetson, since IUnsqueezeLayer was not supported prior to TensorRT 10.7.0
+def unsqueeze_old(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: int,
+) -> TRTTensor:
+    input_val = get_trt_tensor(ctx, input, f"{name}_input")
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"unsqueeze received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    input_shape_size = len(input_val.shape)
+    dim = get_positive_dim(dim, input_shape_size + 1)
+
+    intermediate_dim = 0
+    dynamic_shape_cnt = 0
+    # if unsqueezing the last dimension, we can directly append to the shape
+    if dim == input_shape_size:
+        intermediate_dim = dim
+    else:
+        # since at most one dimension is permitted to be specified as -1,
+        # find the intermediate_dim that keeps at most one dynamic (-1) dimension to its right,
+        # and then add a transpose after the reshape if it is not the final shape we want
+        for i, s in reversed(list(enumerate(input_val.shape))):
+            if i >= dim:
+                if s == -1:
+                    dynamic_shape_cnt += 1
+                if dynamic_shape_cnt > 1:
+                    intermediate_dim = i + 1
+                    break
+            if i == dim:
+                intermediate_dim = i
+                break
+    # calculate the new_shape for the shuffle layer's reshape_dims
+    new_shape = list(
+        tuple(input_val.shape)[:intermediate_dim]
+        + (1,)
+        + tuple(input_val.shape)[intermediate_dim:]
+    )
+    for i, s in enumerate(new_shape):
+        if i < intermediate_dim and s == -1:
+            new_shape[i] = 0
+    layer = ctx.net.add_shuffle(input_val)
+    layer.reshape_dims = tuple(new_shape)
+    # if the intermediate_dim is not the final dim we want to unsqueeze, add a second_transpose after reshape
+    if intermediate_dim != dim:
+        # calculate the second_transpose for the shuffle layer
+        permutation = [*range(0, len(new_shape))]
+        # for example: if the reshape_dims is (3, 3, 5, 1, 5) and the final shape we want is (3, 1, 3, 5, 5)
+        # here intermediate_dim=3, dim=1, we need to move intermediate_dim before [dim: intermediate_dim)
+        new_permutation = (
+            tuple(permutation[:dim])
+            + (intermediate_dim,)
+            + tuple(permutation[dim:intermediate_dim])
+            + tuple(permutation[intermediate_dim + 1 :])
+        )
+        layer.second_transpose = new_permutation
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
+
+
 def broadcast_in_dim(
     ctx: ConversionContext,
     target: Target,
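The version gate introduced in this patch compares importlib.metadata.version("tensorrt") with "10.7.0" as plain strings, which orders lexicographically rather than numerically. A short sketch of the pitfall, with a parsed alternative that assumes the third-party packaging module is available:

from packaging.version import Version

# String comparison proceeds character by character, so "10.10.0"
# sorts *before* "10.7.0" ('1' < '7' at the first differing position):
assert ("10.10.0" >= "10.7.0") is False

# Parsed versions compare numerically, component by component:
assert Version("10.10.0") >= Version("10.7.0")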
From 237344a09020218f3e06f978396723578266a803 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Tue, 20 May 2025 10:19:00 -0700
Subject: [PATCH 6/9] test

---
 py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
index 1fc83a0c2c..f66dd7ae81 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Sequence
+from typing import List, Optional, Sequence, cast
 
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -41,6 +41,8 @@ def unsqueeze_old(
             "of the TensorRT region!"
         )
 
+    dim = cast(int, dim)
+
     input_shape_size = len(input_val.shape)
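One caveat about the dim = cast(int, dim) line added in this commit: typing.cast exists purely for static type checkers and returns its argument unchanged at runtime. A small sketch of the distinction:

from typing import cast

dim = "3"                # imagine a mistyped argument
hinted = cast(int, dim)  # no runtime conversion takes place
assert hinted == "3" and isinstance(hinted, str)

converted = int(dim)     # an actual conversion
assert converted == 3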
) + dim = cast(int, dim) + input_shape_size = len(input_val.shape) dim = get_positive_dim(dim, input_shape_size + 1) From d865623e7ecbdc8bcf152cb2bd5df34b46fc8988 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Tue, 20 May 2025 10:50:53 -0700 Subject: [PATCH 7/9] address comments --- py/torch_tensorrt/_features.py | 4 +++- .../dynamo/conversion/plugins/_generate_plugin.py | 2 +- .../dynamo/conversion/plugins/_generate_plugin_converter.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/_features.py b/py/torch_tensorrt/_features.py index 09d1038cc6..e1cffb5c3a 100644 --- a/py/torch_tensorrt/_features.py +++ b/py/torch_tensorrt/_features.py @@ -83,7 +83,9 @@ def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: else: def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: - raise NotImplementedError("QDP Plugin is not available") + raise NotImplementedError( + "TensorRT QDP(Quick Deploy Plugins) not available, requires TensorRT 10.7.0 or higher" + ) return not_implemented(*args, **kwargs) diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py index 571e6afdcf..b41e1460f5 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py @@ -32,7 +32,7 @@ def _generate_plugin(plugin_name: str) -> None: import tensorrt.plugin as trtp except ImportError as e: raise RuntimeError( - "Unable to import TensorRT plugin. Please install TensorRT plugin library (https://github.com/NVIDIA/TensorRT-plugin-library?tab=readme-ov-file#installation) to add support for compiling quantized models" + "Unable to import TensorRT plugin. TensorRT version must be 10.7.0 or higher to support for Triton based TensorRT plugins" ) namespace, name = plugin_name.split("::") diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py index 926a5779c1..a16eaf7982 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py @@ -34,7 +34,7 @@ def _generate_plugin_converter( except ImportError as e: raise RuntimeError( - "Unable to import TensorRT plugin. Please install TensorRT plugin library (https://github.com/NVIDIA/TensorRT-plugin-library?tab=readme-ov-file#installation) to add support for compiling quantized models" + "Unable to import TensorRT plugin. 
From 186189890249dc40e63e83628a71a6ea1ac1aee5 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Tue, 20 May 2025 11:15:53 -0700
Subject: [PATCH 8/9] test

---
 .../dynamo/conversion/aten_ops_converters.py           | 11 +----------
 py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py  |  4 ++++
 2 files changed, 5 insertions(+), 10 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 26d1a1c157..1fed1f9a1f 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -661,16 +661,7 @@ def aten_ops_unsqueeze(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    from importlib.metadata import version
-
-    if version("tensorrt") >= "10.7.0":
-        return impl.unsqueeze.unsqueeze(
-            ctx, target, SourceIR.ATEN, name, args[0], args[1]
-        )
-    else:
-        return impl.unsqueeze.unsqueeze_old(
-            ctx, target, SourceIR.ATEN, name, args[0], args[1]
-        )
+    return impl.unsqueeze.unsqueeze(ctx, target, SourceIR.ATEN, name, args[0], args[1])
 
 
 @dynamo_tensorrt_converter(
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
index f66dd7ae81..00bf1d31f7 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
@@ -19,6 +19,10 @@ def unsqueeze(
     input: TRTTensor,
     dim: int,
 ) -> TRTTensor:
+    from importlib.metadata import version
+
+    if version("tensorrt") < "10.7.0":
+        return unsqueeze_old(ctx, target, source_ir, name, input, dim)
     axes = get_trt_tensor(ctx, dim, f"{name}_axes")
     layer = ctx.net.add_unsqueeze(input, axes)
     set_layer_name(layer, target, name, source_ir)

From 49b6710e97897b75d552ba4832a26f44a8d47be8 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Tue, 20 May 2025 14:08:17 -0700
Subject: [PATCH 9/9] test

---
 py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
index 00bf1d31f7..02ecf98bfe 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
@@ -1,3 +1,4 @@
+import logging
 from typing import List, Optional, Sequence, cast
 
 from torch.fx.node import Target
@@ -10,6 +11,8 @@
 )
 from torch_tensorrt.dynamo.types import TRTTensor
 
+logger = logging.getLogger(__name__)
+
 
 def unsqueeze(
     ctx: ConversionContext,
@@ -22,6 +25,9 @@ def unsqueeze(
     from importlib.metadata import version
 
     if version("tensorrt") < "10.7.0":
+        logger.warning(
+            f"IUnsqueezeLayer is supported starting from TensorRT 10.7.0, using the old unsqueeze implementation in the current TensorRT version: {version('tensorrt')}"
+        )
         return unsqueeze_old(ctx, target, source_ir, name, input, dim)
     axes = get_trt_tensor(ctx, dim, f"{name}_axes")
     layer = ctx.net.add_unsqueeze(input, axes)
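A closing note on the availability probe this series settles on in _features.py (importlib.util.find_spec("tensorrt.plugin"), from patch 2): for a dotted module name, find_spec imports the parent package first, so it raises ModuleNotFoundError when tensorrt itself is absent rather than returning None. A defensive variant of the probe, assuming nothing about which layer is missing:

import importlib.util


def qdp_plugin_available() -> bool:
    # find_spec("tensorrt.plugin") returns None when only the submodule
    # is missing, but raises ModuleNotFoundError when the parent package
    # "tensorrt" cannot be imported at all; handle both outcomes.
    try:
        return importlib.util.find_spec("tensorrt.plugin") is not None
    except ModuleNotFoundError:
        return False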