diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f31305568d..a7b91eec34 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,14 +21,14 @@ repos: - id: clang-format types_or: [c++, c, cuda] - repo: https://github.com/keith/pre-commit-buildifier - rev: 6.4.0 + rev: 8.0.3 hooks: - id: buildifier args: - --warnings=all - id: buildifier-lint - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.23 + rev: v0.24.1 hooks: - id: validate-pyproject - repo: https://github.com/pycqa/isort @@ -37,17 +37,17 @@ repos: - id: isort name: isort (python) - repo: https://github.com/pre-commit/mirrors-mypy - rev: "v1.9.0" + rev: "v1.15.0" hooks: - id: mypy exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py" - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.3.3 + rev: v0.11.7 hooks: - id: ruff - repo: https://github.com/psf/black - rev: 24.3.0 + rev: 25.1.0 hooks: - id: black exclude: ^examples/custom_converters/elu_converter/setup.py|^docs @@ -57,7 +57,7 @@ repos: - id: typos - repo: https://github.com/astral-sh/uv-pre-commit # uv version. - rev: 0.5.5 + rev: 0.7.1 hooks: # Update the uv lockfile - id: uv-lock diff --git a/MODULE.bazel b/MODULE.bazel index 008c7f53fc..66a879afcf 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -94,9 +94,9 @@ http_archive( http_archive( name = "tensorrt", build_file = "@//third_party/tensorrt/archive:BUILD", - strip_prefix = "TensorRT-10.9.0.34", + strip_prefix = "TensorRT-10.10.0.31", urls = [ - "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.9.0/tars/TensorRT-10.9.0.34.Linux.x86_64-gnu.cuda-12.8.tar.gz", + "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.10.0/tars/TensorRT-10.10.0.31.Linux.x86_64-gnu.cuda-12.9.tar.gz", ], ) diff --git a/examples/dynamo/vgg16_ptq.py b/examples/dynamo/vgg16_ptq.py index 7fa943040e..7af490afb4 100644 --- a/examples/dynamo/vgg16_ptq.py +++ b/examples/dynamo/vgg16_ptq.py @@ -169,7 +169,6 @@ def vgg16(num_classes=1000, init_weights=False): data = iter(training_dataloader) images, _ = next(data) - crit = nn.CrossEntropyLoss() # %% @@ -200,8 +199,11 @@ def calibrate_loop(model): quant_cfg = mtq.INT8_DEFAULT_CFG elif args.quantize_type == "fp8": quant_cfg = mtq.FP8_DEFAULT_CFG +elif args.quantize_type == "fp4": + quant_cfg = mtq.NVFP4_DEFAULT_CFG # PTQ with in-place replacement to quantized modules mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + # model has FP8 qdq nodes at this point # %% @@ -233,12 +235,20 @@ def calibrate_loop(model): with export_torch_mode(): # Compile the model with Torch-TensorRT Dynamo backend input_tensor = images.cuda() + torch.onnx.export(model, input_tensor, "mtq_vgg16_model.onnx") exp_program = torch.export.export(model, (input_tensor,), strict=False) if args.quantize_type == "int8": enabled_precisions = {torch.int8} elif args.quantize_type == "fp8": enabled_precisions = {torch.float8_e4m3fn} + elif args.quantize_type == "fp4": + enabled_precisions = { + torch.float4_e2m1fn_x2, + torch.float8_e4m3fn, + torch.float16, + torch.float32, + } trt_model = torchtrt.dynamo.compile( exp_program, inputs=[input_tensor], diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py index c706c345d6..e0a78e1a0b 100644 --- a/py/torch_tensorrt/_enums.py +++ b/py/torch_tensorrt/_enums.py @@ -80,6 +80,12 @@ class dtype(Enum): :meta hide-value: """ + f4 = auto() + """4 bit floating-point number, equivalent to ``dtype.fp4`` and ``dtype.float4`` + + :meta hide-value: + """ + uint8 = u8 int8 = i8 @@ -91,6 +97,9 @@ class dtype(Enum): float8 = f8 fp8 = f8 + float4 = f4 + fp4 = f4 + half = f16 fp16 = f16 float16 = f16 @@ -162,6 +171,8 @@ def _from( return dtype.i32 elif t == torch.float8_e4m3fn: return dtype.f8 + elif t == torch.float4_e2m1fn_x2: + return dtype.f4 elif t == torch.half: return dtype.f16 elif t == torch.float: @@ -188,6 +199,8 @@ def _from( return dtype.i8 elif t == trt.DataType.FP8: return dtype.f8 + elif t == trt.DataType.FP4: + return dtype.fp4 elif t == trt.DataType.INT32: return dtype.i32 elif t == trt.DataType.INT64: @@ -357,6 +370,8 @@ def to( return torch.long elif self == dtype.f8: return torch.float8_e4m3fn + elif self == dtype.f4: + return torch.float4_e2m1fn_x2 elif self == dtype.f16: return torch.half elif self == dtype.f32: @@ -394,6 +409,8 @@ def to( return trt.DataType.BOOL elif self == dtype.bf16: return trt.DataType.BF16 + elif self == dtype.f4: + return trt.DataType.FP4 elif use_default: return trt.DataType.FLOAT else: @@ -410,6 +427,8 @@ def to( return np.int64 elif self == dtype.f16: return np.float16 + elif self == dtype.f4: + return np.float4_e2m1fn_x2 elif self == dtype.f32: return np.float32 elif self == dtype.f64: diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 831ce37305..4db11daa78 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -581,13 +581,13 @@ def compile( "\nThis feature is unimplemented in Torch-TRT Dynamo currently." ) - if use_explicit_typing: - if len(enabled_precisions) != 1 or not any( - x in enabled_precisions for x in {torch.float32, dtype.f32} - ): - raise AssertionError( - f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}" - ) + # if use_explicit_typing: + # if len(enabled_precisions) != 1 or not any( + # x in enabled_precisions for x in {torch.float32, dtype.f32} + # ): + # raise AssertionError( + # f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}" + # ) if use_fp32_acc: logger.debug( diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index aafd1072f4..921cb37646 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -29,7 +29,14 @@ REQUIRE_FULL_COMPILATION = False DRYRUN = False HARDWARE_COMPATIBLE = False -SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8} +SUPPORTED_KERNEL_PRECISIONS = { + dtype.f32, + dtype.f16, + dtype.bf16, + dtype.i8, + dtype.f8, + dtype.f4, +} TIMING_CACHE_PATH = os.path.join( tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin" ) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 39a1ed957d..ecf08f38c4 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -274,13 +274,13 @@ def _populate_trt_builder_config( self.compilation_settings.dla_global_dram_size, ) - if dtype.float16 in self.compilation_settings.enabled_precisions: + if not self.compilation_settings.use_explicit_typing and dtype.float16 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.FP16) if dtype.int8 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.INT8) - if dtype.fp8 in self.compilation_settings.enabled_precisions: + if not self.compilation_settings.use_explicit_typing and dtype.fp8 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.FP8) if dtype.bfloat16 in self.compilation_settings.enabled_precisions: diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 1fed1f9a1f..08a1bb4ea4 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -617,6 +617,39 @@ def aten_ops_quantize_op( ) +try: + import modelopt.torch.quantization as mtq # noqa: F401 + + assert torch.ops.tensorrt.dynamic_block_quantize_op.default +except Exception as e: + _LOGGER.warning( + "Unable to import dynamic block quantize op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling dynamic blockquantized models" + ) +else: + + @dynamo_tensorrt_converter(torch.ops.tensorrt.dynamic_block_quantize_op.default) + def aten_ops_dynamic_block_quantize_op( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.nvfp4_quantize.nvfp4_quantize( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + args[6], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True) def aten_ops_squeeze( diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 685f40b254..eb18a14eca 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -361,12 +361,37 @@ def create_constant( shape = list(torch_value.shape) if torch_value is not None: + if torch_value.dtype == torch.float8_e4m3fn: + weights = trt.Weights( + type=trt.DataType.FP8, + ptr=torch_value.data_ptr(), + count=torch_value.numel(), + ) + constant = ctx.net.add_constant( + shape, + weights, + ) + constant.name = name + return constant.get_output(0) + # Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8 + if torch_value.dtype == torch.uint8: + weights = trt.Weights( + type=trt.DataType.FP4, + ptr=torch_value.data_ptr(), + count=torch_value.numel() * 2, + ) + shape[-1] = shape[-1] * 2 + constant = ctx.net.add_constant( + shape, + weights, + ) + constant.name = name + return constant.get_output(0) if torch_value.dtype == torch.bfloat16: torch_value_fp32 = torch_value.to(torch.float32) numpy_value = torch_value_fp32.numpy() else: numpy_value = torch_value.numpy() - ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1) constant = ctx.net.add_constant( shape, @@ -381,7 +406,6 @@ def create_constant( trt.DataType.BF16, name + "_bf16_cast", ) - return constant.get_output(0) else: raise ValueError( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index df580b1516..1f2d9d0de1 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -14,6 +14,7 @@ matmul, nccl_ops, normalization, + nvfp4_quantize, pad, permutation, pool, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/addmm.py b/py/torch_tensorrt/dynamo/conversion/impl/addmm.py index 1a0690852a..73d98acfdf 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/addmm.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/addmm.py @@ -7,7 +7,7 @@ from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.fx.types import TRTTensor - +import os def addmm( ctx: ConversionContext, @@ -21,6 +21,10 @@ def addmm( beta: Union[float, int], alpha: Union[float, int], ) -> TRTTensor: + if os.getenv("DISABLE_GEMM", "false").lower() == "true": + print("lan added disable_gemm is set, skip addmm and returning mat2") + return mat2 + print("lan added disable_gemm is not set, doing addmm") mm = impl.matmul.matrix_multiply(ctx, target, source_ir, f"{name}_mm", mat1, mat2) if alpha != 1: mm = impl.elementwise.mul( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py new file mode 100644 index 0000000000..1c2f297764 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py @@ -0,0 +1,368 @@ +import os +from typing import Optional, Union + +import numpy as np +import tensorrt as trt +import torch +import torch_tensorrt.dynamo.conversion.impl as impl +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import ( + get_trt_tensor, + to_torch, +) +from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.types import TRTTensor + + +def nvfp4_quantize( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_tensor: TRTTensor, + block_size: int, + amax: Union[np.ndarray, torch.Tensor], + num_bits: int, + exponent_bits: int, + scale_num_bits: int, + scale_exponent_bits: int, +) -> TRTTensor: + """ + Adds quantize and dequantize ops (QDQ) which quantize to FP4 based + on the output_type set and dequantizes them back. + """ + print( + f"lan added nvfp4_quantize entered: {target=} {source_ir=} {name=} {input_tensor.shape=} {input_tensor.dtype=} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}" + ) + if len(input_tensor.shape) not in (2, 3): + raise ValueError( + f"nvfp4_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D" + ) + with unset_fake_temporarily(): + axis = -1 + global_scale = _calculate_global_scale(ctx, name, amax) + print(f"lan added input_tensor: {input_tensor.shape=} {input_tensor.dtype=}") + print(f"lan added global_scale: {global_scale.shape=} {global_scale.dtype=}") + if ".weight_quantizer" in name: + output = _static_double_quantize( + ctx, + target, + source_ir, + name, + input_tensor, + global_scale, + axis, + ) + elif ".input_quantizer" in name: + # quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 + output = _dynamic_double_quantize( + ctx, + target, + source_ir, + name, + input_tensor, + global_scale, + axis, + ) + + else: + raise ValueError( + f"quantizer received an input of {name}. Supported values: weight_quantizer | input_quantizer" + ) + return output + + +def _dynamic_double_quantize( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_tensor: torch.Tensor, + global_scale: torch.Tensor, + axis: int = -1, + block_size: int = 16, + output_type: trt.DataType = trt.DataType.FP4, + scale_type: trt.DataType = trt.DataType.FP8, +) -> TRTTensor: + """ + quantize input tensor to fp4 + Parameters: + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR] + name: str + input_tensor : Tensor (On GPU) + The input tensor. + global_scale : Tensor (On GPU) + The global per-tensor scaling factor. It should contain only 1 element. + axis : int + The axis to quantize. Default is -1 (the last axis). + block_size : int + The block size for quantization. Default is 16. + output_type : trt.DataType + The data type for quantized data. Default is FP4. + scale_type : trt.DataType + The data type for block scale. Default is FP8. + + """ + if os.getenv("DISABLE_DYNAMIC_QUANTIZE", "false").lower() == "true": + print("lan added disable_dynamic_quantize is set, skipping dynamic quantize") + return input_tensor + print("lan added disable_dynamic_quantize is not set, doing dynamic quantize") + global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") + + if input_tensor.dtype not in [trt.DataType.HALF, trt.DataType.FLOAT]: + raise ValueError( + f"Currently try float16, float32 only on input tensor for now. Unsupported dtype: {input_tensor.dtype}" + ) + # dynamic quantize input tensor to fp4 + dynamic_quantize_layer = ctx.net.add_dynamic_quantize( + input_tensor, + axis, + block_size, + output_type, + scale_type, + ) + dynamic_quantize_layer.set_input(1, global_scale) + set_layer_name( + dynamic_quantize_layer, target, name + "_dynamic_quantize", source_ir + ) + quantized_data_in_fp4 = dynamic_quantize_layer.get_output(0) + quantized_scale_in_fp8 = dynamic_quantize_layer.get_output(1) + + return _double_dequantize( + ctx, + target, + source_ir, + name, + quantized_data_in_fp4, + quantized_scale_in_fp8, + global_scale, + axis, + input_tensor.dtype, + ) + + +def _double_dequantize( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + quantized_data_in_fp4: TRTTensor, + quantized_scale_in_fp8: TRTTensor, + global_scale: torch.Tensor, + axis: int = -1, + output_type: trt.DataType = trt.DataType.FLOAT, +) -> TRTTensor: + # dequantize scale from fp8 to orignal dtype(default is float32) + dequantize_scale_layer = ctx.net.add_dequantize( + quantized_scale_in_fp8, global_scale, output_type + ) + dequantize_scale_layer.axis = axis + dequantize_scale_layer.to_type = output_type + set_layer_name( + dequantize_scale_layer, target, name + "_dequantize_scale", source_ir + ) + dequantized_scale = dequantize_scale_layer.get_output(0) + + # dequantize quantized_data_in_fp4 from fp4 to orignal dtype(default is float32) + dequantize_data_layer = ctx.net.add_dequantize( + quantized_data_in_fp4, dequantized_scale, output_type + ) + dequantize_data_layer.axis = axis + dequantize_data_layer.to_type = output_type + set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) + dequantized_data = dequantize_data_layer.get_output(0) + return dequantized_data + + +def _static_double_quantize( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + weights_tensor: torch.Tensor, + global_scale: torch.Tensor, + axis: int, +) -> TRTTensor: + """ + Parameters: + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + weights_tensor : Tensor (On GPU) + The input tensor for weights. + global_scale : Tensor (On GPU) + The global per-tensor scaling factor. It should contain only 1 element. + axis: int + The axis to quantize. Default is -1 (the last axis). + Returns: + quantized data tensor in fp4 + """ + if os.getenv("DISABLE_STATIC_QUANTIZE", "false").lower() == "true": + print("lan added disable_static_quantize is set, skipping static quantize") + return get_trt_tensor(ctx, weights_tensor, name + "_weights") + print( + "lan added static disable_static_quantize is not set, doing static double quantize " + ) + import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor + + if weights_tensor.dtype == torch.float16: + original_dtype = trt.DataType.HALF + elif weights_tensor.dtype == torch.float32: + original_dtype = trt.DataType.FLOAT + else: + raise ValueError( + f"Currently try float16, float32 only on weights tensor. Unsupported dtype: {weights_tensor.dtype}" + ) + block_scale_fp8 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( + weights_tensor, + 16, + global_scale, + )[0] + weights_tensor_fp4 = nvfp4_tensor.NVFP4QTensor.quantize( + weights_tensor, + 16, + block_scale_fp8, + global_scale, + )[0]._quantized_data + + block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8") + global_scale = to_torch(global_scale, None) + global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") + weights_tensor_fp4 = get_trt_tensor(ctx, weights_tensor_fp4, name + "_weights_fp4") + print( + f"lan added weights_tensor_fp4: {weights_tensor_fp4.shape=} {weights_tensor_fp4.dtype=}" + ) + dequantized_data = _double_dequantize( + ctx, + target, + source_ir, + name, + weights_tensor_fp4, + block_scale_fp8, + global_scale, + axis, + original_dtype, + ) + return dequantized_data + + +def _static_double_quantize_transpose( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + weights_tensor: torch.Tensor, + global_scale: torch.Tensor, + axis: int, +) -> TRTTensor: + """ + Parameters: + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + weights_tensor : Tensor (On GPU) + The input tensor for weights. + global_scale : Tensor (On GPU) + The global per-tensor scaling factor. It should contain only 1 element. + axis: int + The axis to quantize. Default is -1 (the last axis). + Returns: + quantized data tensor in fp4 + """ + axis = -2 + import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor + + if weights_tensor.dtype == torch.float16: + original_dtype = trt.DataType.HALF + elif weights_tensor.dtype == torch.float32: + original_dtype = trt.DataType.FLOAT + else: + raise ValueError( + f"Currently try float16, float32 only on weights tensor. Unsupported dtype: {weights_tensor.dtype}" + ) + block_scale = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( + weights_tensor, + 16, + global_scale, + keep_high_precision=True, + )[0] + weights_tensor_scaled = nvfp4_tensor.NVFP4QTensor.quantize( + weights_tensor, + 16, + block_scale, + global_scale, + keep_high_precision=True, + ) + + block_scale = block_scale.transpose(0, 1) + weights_tensor_scaled = weights_tensor_scaled.transpose(0, 1) + + block_scale_fp8 = block_scale.to(torch.float8_e4m3fn) + weights_tensor_uint4 = nvfp4_tensor.NVFP4QTensor._cast_fp4(weights_tensor_scaled) + weights_tensor_uint8 = ( + weights_tensor_uint4[..., 1::2] << 4 + ) | weights_tensor_uint4[..., 0::2] + + print( + f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}" + ) + print( + f"lan added weights_tensor_uint4: {weights_tensor_uint4.shape=} {weights_tensor_uint4.dtype=} {weights_tensor_uint4=}" + ) + print( + f"lan added weights_tensor_uint8: {weights_tensor_uint8.shape=} {weights_tensor_uint8.dtype=} {weights_tensor_uint8=}" + ) + + block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8") + global_scale = to_torch(global_scale, None) + global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") + weights_tensor_fp4 = get_trt_tensor( + ctx, weights_tensor_uint8, name + "_weights_fp4" + ) + print( + f"lan added weights_tensor_fp4: {weights_tensor_fp4.shape=} {weights_tensor_fp4.dtype=}" + ) + dequantized_data = _double_dequantize( + ctx, + target, + source_ir, + name, + weights_tensor_fp4, + block_scale_fp8, + global_scale, + axis, + original_dtype, + ) + dequantized_data = impl.permutation.permute( + ctx, + target, + source_ir, + name + "_dequantized_data_transposed", + dequantized_data, + (-1, -2), + ) + return dequantized_data + + +def _calculate_global_scale( + ctx: ConversionContext, + name: str, + amax: torch.Tensor, +) -> torch.Tensor: + # calculate global scale (the global per-tensor scaling factor, should only contain 1 element) + if amax is None or amax == 0: + amax = 1.0 + amax = to_torch( + amax, None + ) # amax is calculated from input_tensor.abs().amax().float() + global_scale = torch.divide(amax, 6 * 448) + if global_scale == 0: + global_scale = 1.0 + return global_scale diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py index 1537d0fdbe..4408b62809 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -13,7 +13,7 @@ ) from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape from torch_tensorrt.fx.types import TRTTensor - +import os def permute( ctx: ConversionContext, @@ -23,6 +23,10 @@ def permute( input: TRTTensor, permutation: Sequence[int], ) -> TRTTensor: + if os.getenv("DISABLE_GEMM", "false").lower() == "true": + print("lan added disable_gemm is set, skip permute") + return input + print("lan added disable_gemm is not set, doing permute") if not isinstance(input, TRTTensor): raise RuntimeError( f"permute received input {input} that is not a TensorRT ITensor" diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 6ebefc5509..190b6752b4 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -101,4 +101,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # TODO: Update this function when quantization is added def is_impure(self, node: torch.fx.node.Node) -> bool: + if node.target in ( + torch.ops.tensorrt.quantize_op.default, + torch.ops.tensorrt.dynamic_block_quantize_op.default, + ): + return True return False diff --git a/pyproject.toml b/pyproject.toml index 3bb857e3e0..e2878bc7bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ requires = [ "cffi>=1.15.1", "typing-extensions>=4.7.0", "future>=0.18.3", - "tensorrt-cu12>=10.9.0,<10.10.0", + "tensorrt-cu12>=10.10.0,<10.11.0", "torch>=2.8.0.dev,<2.9.0", "pybind11==2.6.2", "numpy", @@ -56,10 +56,10 @@ keywords = [ ] dependencies = [ "torch>=2.8.0.dev,<2.9.0", - "tensorrt>=10.9.0,<10.10.0", - "tensorrt-cu12>=10.9.0,<10.10.0", - "tensorrt-cu12-bindings>=10.9.0,<10.10.0", - "tensorrt-cu12-libs>=10.9.0,<10.10.0", + "tensorrt>=10.10.0,<10.11.0", + "tensorrt-cu12>=10.10.0,<10.11.0", + "tensorrt-cu12-bindings>=10.10.0,<10.11.0", + "tensorrt-cu12-libs>=10.10.0,<10.11.0", "packaging>=23", "numpy", "typing-extensions>=4.7.0", diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 0c28b23bba..175d3d79d7 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -3,7 +3,6 @@ import platform import unittest from importlib import metadata - import pytest import timm import torch @@ -14,7 +13,7 @@ from packaging.version import Version assertions = unittest.TestCase() - +import os @pytest.mark.unit def test_resnet18(ir): @@ -199,6 +198,88 @@ def test_resnet18_half(ir): torch._dynamo.reset() +# @unittest.skipIf( +# torch.cuda.get_device_capability() < (10, 0), +# "FP4 quantization requires compute capability 10.0 or later", +# ) +@unittest.skipIf( + not importlib.util.find_spec("modelopt"), + "ModelOpt is required to run this test", +) +@pytest.mark.unit +def test_base_fp4(ir): + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.utils import export_torch_mode + dtype = torch.float16 + + class SimpleNetwork(torch.nn.Module): + def __init__(self): + super(SimpleNetwork, self).__init__() + self.linear1 = torch.nn.Linear( + in_features=64, out_features=32, bias=True, dtype=dtype + ) + + def forward(self, x): + x = self.linear1(x) + return x + + def calibrate_loop(model): + """Simple calibration function for testing.""" + model(input_tensor) + + input_tensor = torch.ones(128, 64, dtype=dtype).cuda() + + + model = SimpleNetwork().eval().cuda() + model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=dtype).cuda()) + model.linear1.bias = torch.nn.Parameter(torch.zeros(128, 32, dtype=dtype).cuda()) + print(f"lan added amax: {input_tensor.abs().amax()=}") + print(f"lan added amax: {model.linear1.weight.abs().amax()=}") + expected_output = model(input_tensor) + print(f"lan added model input: {input_tensor=}") + print(f"lan added model weight: {model.linear1.weight=}") + print(f"lan added model bias: {model.linear1.bias=}") + + quant_cfg = mtq.NVFP4_DEFAULT_CFG + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + # model has qdq nodes at this point + with torch.no_grad(): + with export_torch_mode(): + exp_program = torch.export.export(model, (input_tensor,), strict=False) + from torch.fx import passes + + trt_model = torchtrt.dynamo.compile( + exp_program, + inputs=[input_tensor], + enabled_precisions={ + torch.float4_e2m1fn_x2, + torch.float8_e4m3fn, + torch.float32, + torch.float16, + }, + min_block_size=1, + debug=True, + cache_built_engines=False, + reuse_cached_engines=False, + use_explicit_typing=dtype == torch.float16, + ) + + outputs_trt = trt_model(input_tensor) + if os.getenv("DISABLE_GEMM", "false").lower() == "true": + print("lan added disable_gemm is set, compring result with weights") + expected_output = model.linear1.weight + else: + print("lan added disable_gemm is not set, compring result with pytorch") + + print(f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}") + print(f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}") + + abs_diff = torch.abs(expected_output - outputs_trt) + print(f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}") + print(f"lan added abs_diff: {abs_diff=}") + assert torch.allclose(expected_output, outputs_trt, rtol=0.8, atol=0.8) + + @unittest.skipIf( torch.cuda.get_device_capability() < (8, 9), "FP8 quantization requires compute capability 8.9 or later", @@ -230,8 +311,8 @@ def calibrate_loop(model): input_tensor = torch.randn(1, 10).cuda() model = SimpleNetwork().eval().cuda() - quant_cfg = mtq.FP8_DEFAULT_CFG + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) # model has FP8 qdq nodes at this point output_pyt = model(input_tensor)