From 1d172ced5c0af96204af6ecc19ba5c6d23a93570 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Tue, 29 Apr 2025 09:22:36 -0700
Subject: [PATCH 01/30] Add fp4 support

---
 examples/dynamo/vgg16_ptq.py                  |  4 ++
 py/torch_tensorrt/_enums.py                   | 15 ++++++
 .../dynamo/conversion/aten_ops_converters.py  | 33 ++++++++++++
 .../dynamo/conversion/impl/quantize.py        | 52 ++++++++++++++++++
 tests/py/dynamo/models/test_models_export.py  | 54 +++++++++++++++++++
 5 files changed, 158 insertions(+)

diff --git a/examples/dynamo/vgg16_ptq.py b/examples/dynamo/vgg16_ptq.py
index 7fa943040e..0ed8772a44 100644
--- a/examples/dynamo/vgg16_ptq.py
+++ b/examples/dynamo/vgg16_ptq.py
@@ -200,6 +200,8 @@ def calibrate_loop(model):
     quant_cfg = mtq.INT8_DEFAULT_CFG
 elif args.quantize_type == "fp8":
     quant_cfg = mtq.FP8_DEFAULT_CFG
+elif args.quantize_type == "fp4":
+    quant_cfg = mtq.NVFP4_DEFAULT_CFG
 # PTQ with in-place replacement to quantized modules
 mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
 # model has FP8 qdq nodes at this point
@@ -239,6 +241,8 @@ def calibrate_loop(model):
             enabled_precisions = {torch.int8}
         elif args.quantize_type == "fp8":
             enabled_precisions = {torch.float8_e4m3fn}
+        elif args.quantize_type == "fp4":
+            enabled_precisions = {torch.float4_e2m1fn_x2}
         trt_model = torchtrt.dynamo.compile(
             exp_program,
             inputs=[input_tensor],
diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py
index c706c345d6..e45ed7375b 100644
--- a/py/torch_tensorrt/_enums.py
+++ b/py/torch_tensorrt/_enums.py
@@ -76,6 +76,12 @@ class dtype(Enum):
     f8 = auto()
     """8 bit floating-point number, equivalent to ``dtype.fp8`` and ``dtype.float8``
+
+    :meta hide-value:
+    """
+
+    f4 = auto()
+    """4 bit floating-point number, equivalent to ``dtype.fp4`` and ``dtype.float4``
 
     :meta hide-value:
     """
@@ -90,6 +96,7 @@ class dtype(Enum):
 
     float8 = f8
     fp8 = f8
+    fp4 = f4
 
     half = f16
     fp16 = f16
@@ -162,6 +169,8 @@ def _from(
                 return dtype.i32
             elif t == torch.float8_e4m3fn:
                 return dtype.f8
+            elif t == torch.float4_e2m1fn_x2:
+                return dtype.f4
             elif t == torch.half:
                 return dtype.f16
             elif t == torch.float:
@@ -188,6 +197,8 @@ def _from(
                 return dtype.i8
             elif t == trt.DataType.FP8:
                 return dtype.f8
+            elif t == trt.DataType.FP4:
+                return dtype.fp4
             elif t == trt.DataType.INT32:
                 return dtype.i32
             elif t == trt.DataType.INT64:
@@ -357,6 +368,8 @@ def to(
                 return torch.long
             elif self == dtype.f8:
                 return torch.float8_e4m3fn
+            elif self == dtype.f4:
+                return torch.float4_e2m1fn_x2
             elif self == dtype.f16:
                 return torch.half
             elif self == dtype.f32:
@@ -410,6 +423,8 @@ def to(
                 return np.int64
             elif self == dtype.f16:
                 return np.float16
+            elif self == dtype.f4:
+                return np.float4_e2m1fn_x2
             elif self == dtype.f32:
                 return np.float32
             elif self == dtype.f64:
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 1fed1f9a1f..af5a783fcd 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -617,6 +617,39 @@ def aten_ops_quantize_op(
     )
 
 
+try:
+    import modelopt.torch.quantization as mtq  # noqa: F401
+
+    assert torch.ops.tensorrt.dynamic_block_quantize_op.default
+except Exception as e:
+    _LOGGER.warning(
+        "Unable to import dynamic block quantize op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling dynamic blockquantized models"
+    )
+else:
+
+    @dynamo_tensorrt_converter(torch.ops.tensorrt.dynamic_block_quantize_op.default)
+    def aten_ops_dynamic_block_quantize_op(
+        ctx: ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+        return impl.quantize.dynamic_block_quantize(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+            args[1],
+            args[2],
+            args[3],
+            args[4],
+            args[5],
+            args[6],
+        )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True)
 def aten_ops_squeeze(
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py
index e472ed3092..d4ca457cc7 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py
@@ -67,3 +67,55 @@ def quantize(
         dq_output = dequantize_layer.get_output(0)
 
         return dq_output
+
+def dynamic_block_quantize(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_tensor: TRTTensor,
+    block_size: int,
+    amax: Union[np.ndarray, torch.Tensor],
+    num_bits: int,
+    exponent_bits: int,
+    scale_num_bits: int,
+    scale_exponent_bits: int,
+) -> TRTTensor:
+    """
+    Adds quantize and dequantize ops (QDQ) which quantize to FP4 based
+    on the output_type set and dequantizes them back.
+    """
+
+    with unset_fake_temporarily():
+        if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in (
+            trt.float32,
+            trt.float16,
+            trt.bfloat16,
+        ):
+            raise ValueError(
+                f"dynamic_block_quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16 | bfloat16"
+            )
+        if len(input_tensor.shape) not in (2, 3):
+            raise ValueError(
+                f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. 
Supported shapes: 2D or 3D" + ) + print(f"input_tensor.shape: {input_tensor.shape} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}") + max_bound = 6 + amax = to_torch(amax, None) + scale = torch.divide(amax, max_bound) + scale = get_trt_tensor(ctx, scale, name + "_scale") + + output_type=trt.DataType.FP4 + # Add Q node + dynamic_quantize_layer = ctx.net.add_dynamic_quantize(input_tensor, axis=-1, block_size=16, output_type=output_type) + quantize_layer.set_output_type(0, output_type) + + set_layer_name(quantize_layer, target, name + "_quantize", source_ir) + q_output = quantize_layer.get_output(0) + # Add DQ node + dequantize_layer = ctx.net.add_dequantize(q_output, scale) + set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir) + dequantize_layer.precision = output_type + dq_output = dequantize_layer.get_output(0) + + return dq_output diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 19fdeaa9ab..7cf9a7cc9f 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -199,6 +199,60 @@ def test_resnet18_half(ir): torch._dynamo.reset() + +@unittest.skipIf( + torch.cuda.get_device_capability() < (8, 9), + "FP4 quantization requires compute capability 8.9 or later", +) +@unittest.skipIf( + not importlib.util.find_spec("modelopt"), + "ModelOpt is required to run this test", +) +@pytest.mark.unit +def test_base_fp4(ir): + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.utils import export_torch_mode + + class SimpleNetwork(torch.nn.Module): + def __init__(self): + super(SimpleNetwork, self).__init__() + self.linear1 = torch.nn.Linear(in_features=10, out_features=5) + self.linear2 = torch.nn.Linear(in_features=5, out_features=1) + + def forward(self, x): + x = self.linear1(x) + x = torch.nn.ReLU()(x) + x = self.linear2(x) + return x + + def calibrate_loop(model): + """Simple calibration function for testing.""" + model(input_tensor) + + input_tensor = torch.randn(1, 10).cuda() + model = SimpleNetwork().eval().cuda() + + quant_cfg = mtq.NVFP4_DEFAULT_CFG + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + # model has FP8 qdq nodes at this point + output_pyt = model(input_tensor) + + with torch.no_grad(): + with export_torch_mode(): + exp_program = torch.export.export(model, (input_tensor,), strict=False) + trt_model = torchtrt.dynamo.compile( + exp_program, + inputs=[input_tensor], + enabled_precisions={torch.float4_e2m1fn_x2}, + min_block_size=1, + debug=True, + cache_built_engines=False, + reuse_cached_engines=False, + ) + outputs_trt = trt_model(input_tensor) + assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=5e-1) + + @unittest.skipIf( torch.cuda.get_device_capability() < (8, 9), "FP8 quantization requires compute capability 8.9 or later", From d2b1422b7f3d68c31cda4d19298ed7f94c16c0bc Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 30 Apr 2025 10:01:55 -0700 Subject: [PATCH 02/30] test --- .pre-commit-config.yaml | 12 +++++------ py/torch_tensorrt/_enums.py | 4 ++-- py/torch_tensorrt/dynamo/_defaults.py | 9 +++++++- .../dynamo/conversion/impl/quantize.py | 21 ++++++++++++------- tests/py/dynamo/models/test_models_export.py | 13 ++++++------ 5 files changed, 36 insertions(+), 23 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f31305568d..a7b91eec34 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,14 
+21,14 @@ repos: - id: clang-format types_or: [c++, c, cuda] - repo: https://github.com/keith/pre-commit-buildifier - rev: 6.4.0 + rev: 8.0.3 hooks: - id: buildifier args: - --warnings=all - id: buildifier-lint - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.23 + rev: v0.24.1 hooks: - id: validate-pyproject - repo: https://github.com/pycqa/isort @@ -37,17 +37,17 @@ repos: - id: isort name: isort (python) - repo: https://github.com/pre-commit/mirrors-mypy - rev: "v1.9.0" + rev: "v1.15.0" hooks: - id: mypy exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py" - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.3.3 + rev: v0.11.7 hooks: - id: ruff - repo: https://github.com/psf/black - rev: 24.3.0 + rev: 25.1.0 hooks: - id: black exclude: ^examples/custom_converters/elu_converter/setup.py|^docs @@ -57,7 +57,7 @@ repos: - id: typos - repo: https://github.com/astral-sh/uv-pre-commit # uv version. - rev: 0.5.5 + rev: 0.7.1 hooks: # Update the uv lockfile - id: uv-lock diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py index e45ed7375b..2bda1c2868 100644 --- a/py/torch_tensorrt/_enums.py +++ b/py/torch_tensorrt/_enums.py @@ -76,10 +76,10 @@ class dtype(Enum): f8 = auto() """8 bit floating-point number, equivalent to ``dtype.fp8`` and ``dtype.float8`` - + :meta hide-value: """ - + f4 = auto() """4 bit floating-point number, equivalent to ``dtype.fp4`` and ``dtype.float4`` diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 379a196e2e..ded8adfb01 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -29,7 +29,14 @@ REQUIRE_FULL_COMPILATION = False DRYRUN = False HARDWARE_COMPATIBLE = False -SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8} +SUPPORTED_KERNEL_PRECISIONS = { + dtype.f32, + dtype.f16, + dtype.bf16, + dtype.i8, + dtype.f8, + dtype.f4, +} TIMING_CACHE_PATH = os.path.join( tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin" ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py index d4ca457cc7..8dd3ca501b 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py @@ -68,6 +68,7 @@ def quantize( return dq_output + def dynamic_block_quantize( ctx: ConversionContext, target: Target, @@ -99,23 +100,29 @@ def dynamic_block_quantize( raise ValueError( f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. 
Supported shapes: 2D or 3D" ) - print(f"input_tensor.shape: {input_tensor.shape} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}") max_bound = 6 amax = to_torch(amax, None) scale = torch.divide(amax, max_bound) scale = get_trt_tensor(ctx, scale, name + "_scale") - output_type=trt.DataType.FP4 # Add Q node - dynamic_quantize_layer = ctx.net.add_dynamic_quantize(input_tensor, axis=-1, block_size=16, output_type=output_type) - quantize_layer.set_output_type(0, output_type) + dynamic_quantize_layer = ctx.net.add_dynamic_quantize( + input_tensor, + axis=-1, + block_size=16, + output_type=trt.DataType.FP4, + scale_type=trt.DataType.FP8, + ) + dynamic_quantize_layer.set_output_type(0, trt.DataType.FP4) - set_layer_name(quantize_layer, target, name + "_quantize", source_ir) - q_output = quantize_layer.get_output(0) + set_layer_name( + dynamic_quantize_layer, target, name + "_dynamic_quantize", source_ir + ) + q_output = dynamic_quantize_layer.get_output(0) # Add DQ node dequantize_layer = ctx.net.add_dequantize(q_output, scale) set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir) - dequantize_layer.precision = output_type + dequantize_layer.precision = trt.DataType.FP4 dq_output = dequantize_layer.get_output(0) return dq_output diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 7cf9a7cc9f..c6de05ac27 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -199,10 +199,9 @@ def test_resnet18_half(ir): torch._dynamo.reset() - @unittest.skipIf( - torch.cuda.get_device_capability() < (8, 9), - "FP4 quantization requires compute capability 8.9 or later", + torch.cuda.get_device_capability() < (10, 0), + "FP4 quantization requires compute capability 10.0 or later", ) @unittest.skipIf( not importlib.util.find_spec("modelopt"), @@ -216,8 +215,8 @@ def test_base_fp4(ir): class SimpleNetwork(torch.nn.Module): def __init__(self): super(SimpleNetwork, self).__init__() - self.linear1 = torch.nn.Linear(in_features=10, out_features=5) - self.linear2 = torch.nn.Linear(in_features=5, out_features=1) + self.linear1 = torch.nn.Linear(in_features=32, out_features=16) + self.linear2 = torch.nn.Linear(in_features=16, out_features=1) def forward(self, x): x = self.linear1(x) @@ -229,12 +228,12 @@ def calibrate_loop(model): """Simple calibration function for testing.""" model(input_tensor) - input_tensor = torch.randn(1, 10).cuda() + input_tensor = torch.randn(1, 32).cuda() model = SimpleNetwork().eval().cuda() quant_cfg = mtq.NVFP4_DEFAULT_CFG mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - # model has FP8 qdq nodes at this point + # model has FP4 qdq nodes at this point output_pyt = model(input_tensor) with torch.no_grad(): From d439d969bd4aa971117d3c2befc7a8489306456d Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 1 May 2025 16:38:37 -0700 Subject: [PATCH 03/30] upgrade modelopt --- py/torch_tensorrt/_enums.py | 2 ++ pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py index 2bda1c2868..6d3250c681 100644 --- a/py/torch_tensorrt/_enums.py +++ b/py/torch_tensorrt/_enums.py @@ -96,6 +96,8 @@ class dtype(Enum): float8 = f8 fp8 = f8 + + float4 = f4 fp4 = f4 half = f16 diff --git a/pyproject.toml b/pyproject.toml index 87c70fec17..3bb857e3e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,7 @@ dev = [ torchvision = [ 
"torchvision", ] #Leaving torchvisions dependency unconstrained so uv can just install something that should work for the torch we have. TV's on PyT makes it hard to put version constrains in -quantization = ["nvidia-modelopt[deploy,hf,torch]>=0.17.0"] +quantization = ["nvidia-modelopt[all]>=0.27.1"] monitoring-tools = ["rich>=13.7.1"] jupyter = ["rich[jupyter]>=13.7.1"] distributed = ["tensorrt-llm>=0.16.0"] From 5a2213ed25386994198088cbfa86343d1379ea7c Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 1 May 2025 20:45:33 -0700 Subject: [PATCH 04/30] add constant fold --- py/torch_tensorrt/dynamo/conversion/impl/quantize.py | 9 ++++----- .../dynamo/lowering/passes/constant_folding.py | 5 +++++ 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py index 8dd3ca501b..5881934803 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py @@ -108,12 +108,11 @@ def dynamic_block_quantize( # Add Q node dynamic_quantize_layer = ctx.net.add_dynamic_quantize( input_tensor, - axis=-1, - block_size=16, - output_type=trt.DataType.FP4, - scale_type=trt.DataType.FP8, + -1, + 16, + trt.DataType.FP4, + trt.DataType.FP8, ) - dynamic_quantize_layer.set_output_type(0, trt.DataType.FP4) set_layer_name( dynamic_quantize_layer, target, name + "_dynamic_quantize", source_ir diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 6ebefc5509..190b6752b4 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -101,4 +101,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # TODO: Update this function when quantization is added def is_impure(self, node: torch.fx.node.Node) -> bool: + if node.target in ( + torch.ops.tensorrt.quantize_op.default, + torch.ops.tensorrt.dynamic_block_quantize_op.default, + ): + return True return False From fcf0c123a061b66ce94db6fc84655f916121cb20 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 1 May 2025 21:47:32 -0700 Subject: [PATCH 05/30] fix the input tensor type issue --- py/torch_tensorrt/dynamo/conversion/impl/quantize.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py index 5881934803..63486e9eb4 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py @@ -28,6 +28,8 @@ def quantize( """ with unset_fake_temporarily(): + if not isinstance(input_tensor, TRTTensor): + input_tensor = get_trt_tensor(ctx, input_tensor, name + "_quantize_input") if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in ( trt.float32, trt.float16, @@ -88,6 +90,10 @@ def dynamic_block_quantize( """ with unset_fake_temporarily(): + if not isinstance(input_tensor, TRTTensor): + input_tensor = get_trt_tensor( + ctx, input_tensor, name + "_dynamic_quantize_input" + ) if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in ( trt.float32, trt.float16, From 057f35afc8253432ad83e1255df43c0d8a862ba0 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 5 May 2025 13:30:27 -0700 Subject: [PATCH 06/30] test --- examples/dynamo/vgg16_ptq.py | 4 --- .../dynamo/conversion/impl/quantize.py | 28 +++++++++++++------ 
tests/py/dynamo/models/test_models_export.py | 16 ++++------- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/examples/dynamo/vgg16_ptq.py b/examples/dynamo/vgg16_ptq.py index 0ed8772a44..7fa943040e 100644 --- a/examples/dynamo/vgg16_ptq.py +++ b/examples/dynamo/vgg16_ptq.py @@ -200,8 +200,6 @@ def calibrate_loop(model): quant_cfg = mtq.INT8_DEFAULT_CFG elif args.quantize_type == "fp8": quant_cfg = mtq.FP8_DEFAULT_CFG -elif args.quantize_type == "fp4": - quant_cfg = mtq.NVFP4_DEFAULT_CFG # PTQ with in-place replacement to quantized modules mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) # model has FP8 qdq nodes at this point @@ -241,8 +239,6 @@ def calibrate_loop(model): enabled_precisions = {torch.int8} elif args.quantize_type == "fp8": enabled_precisions = {torch.float8_e4m3fn} - elif args.quantize_type == "fp4": - enabled_precisions = {torch.float4_e2m1fn_x2} trt_model = torchtrt.dynamo.compile( exp_program, inputs=[input_tensor], diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py index 63486e9eb4..2961d58247 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py @@ -88,7 +88,9 @@ def dynamic_block_quantize( Adds quantize and dequantize ops (QDQ) which quantize to FP4 based on the output_type set and dequantizes them back. """ - + print( + f"dynamic_block_quantize entered: {target=} {source_ir=} {name=} {input_tensor.shape=} {input_tensor.dtype=} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}" + ) with unset_fake_temporarily(): if not isinstance(input_tensor, TRTTensor): input_tensor = get_trt_tensor( @@ -114,18 +116,28 @@ def dynamic_block_quantize( # Add Q node dynamic_quantize_layer = ctx.net.add_dynamic_quantize( input_tensor, - -1, - 16, - trt.DataType.FP4, - trt.DataType.FP8, + axis=1, + block_size=16, + output_type=trt.DataType.FP4, + scale_type=trt.DataType.FP8, ) - set_layer_name( dynamic_quantize_layer, target, name + "_dynamic_quantize", source_ir ) q_output = dynamic_quantize_layer.get_output(0) - # Add DQ node - dequantize_layer = ctx.net.add_dequantize(q_output, scale) + q_scale = dynamic_quantize_layer.get_output(1) + + # Add double DQ node + scale_dequantize_layer = ctx.net.add_dequantize(q_scale, scale) + scale_dequantize_layer.axis = 0 + set_layer_name( + scale_dequantize_layer, target, name + "_scale_dequantize", source_ir + ) + scale_dequantize_layer.precision = trt.DataType.FP8 + scale_dq_output = scale_dequantize_layer.get_output(0) + + dequantize_layer = ctx.net.add_dequantize(q_output, scale_dq_output) + dequantize_layer.axis = 1 set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir) dequantize_layer.precision = trt.DataType.FP4 dq_output = dequantize_layer.get_output(0) diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 381aac312e..c7985c5cdb 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -199,10 +199,10 @@ def test_resnet18_half(ir): torch._dynamo.reset() -@unittest.skipIf( - torch.cuda.get_device_capability() < (10, 0), - "FP4 quantization requires compute capability 10.0 or later", -) +# @unittest.skipIf( +# torch.cuda.get_device_capability() < (10, 0), +# "FP4 quantization requires compute capability 10.0 or later", +# ) @unittest.skipIf( not importlib.util.find_spec("modelopt"), "ModelOpt is required to 
run this test", @@ -215,20 +215,17 @@ def test_base_fp4(ir): class SimpleNetwork(torch.nn.Module): def __init__(self): super(SimpleNetwork, self).__init__() - self.linear1 = torch.nn.Linear(in_features=32, out_features=16) - self.linear2 = torch.nn.Linear(in_features=16, out_features=1) + self.linear1 = torch.nn.Linear(in_features=16, out_features=5) def forward(self, x): x = self.linear1(x) - x = torch.nn.ReLU()(x) - x = self.linear2(x) return x def calibrate_loop(model): """Simple calibration function for testing.""" model(input_tensor) - input_tensor = torch.randn(1, 32).cuda() + input_tensor = torch.randn(1, 16).cuda() model = SimpleNetwork().eval().cuda() quant_cfg = mtq.NVFP4_DEFAULT_CFG @@ -283,7 +280,6 @@ def calibrate_loop(model): input_tensor = torch.randn(1, 10).cuda() model = SimpleNetwork().eval().cuda() - quant_cfg = mtq.FP8_DEFAULT_CFG mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) # model has FP8 qdq nodes at this point From 7b09862bb347d1820cad429d738fdec6fc3a0acd Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Tue, 6 May 2025 12:32:47 -0700 Subject: [PATCH 07/30] test --- .../dynamo/conversion/impl/quantize.py | 215 +++++++++++++++--- tests/py/dynamo/models/test_models_export.py | 2 +- 2 files changed, 179 insertions(+), 38 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py index 2961d58247..7939872e0e 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Tuple, Union import numpy as np import tensorrt as trt @@ -93,9 +93,7 @@ def dynamic_block_quantize( ) with unset_fake_temporarily(): if not isinstance(input_tensor, TRTTensor): - input_tensor = get_trt_tensor( - ctx, input_tensor, name + "_dynamic_quantize_input" - ) + input_tensor = get_trt_tensor(ctx, input_tensor, name + "_input") if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in ( trt.float32, trt.float16, @@ -104,42 +102,185 @@ def dynamic_block_quantize( raise ValueError( f"dynamic_block_quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16 | bfloat16" ) - if len(input_tensor.shape) not in (2, 3): - raise ValueError( - f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. 
Supported shapes: 2D or 3D" - ) + + # calculate global scale (the global per-tensor scaling factor, should only contain 1 element) max_bound = 6 amax = to_torch(amax, None) - scale = torch.divide(amax, max_bound) - scale = get_trt_tensor(ctx, scale, name + "_scale") + global_scale = torch.divide(amax, max_bound) + global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") - # Add Q node - dynamic_quantize_layer = ctx.net.add_dynamic_quantize( - input_tensor, - axis=1, - block_size=16, - output_type=trt.DataType.FP4, - scale_type=trt.DataType.FP8, - ) - set_layer_name( - dynamic_quantize_layer, target, name + "_dynamic_quantize", source_ir - ) - q_output = dynamic_quantize_layer.get_output(0) - q_scale = dynamic_quantize_layer.get_output(1) - - # Add double DQ node - scale_dequantize_layer = ctx.net.add_dequantize(q_scale, scale) - scale_dequantize_layer.axis = 0 - set_layer_name( - scale_dequantize_layer, target, name + "_scale_dequantize", source_ir + if ".weight_quantizer" in name: + # static double quantization is used for weights + q_output, q_scale = _static_double_quantize( + ctx, + target, + source_ir, + name, + input_tensor, + global_scale, + ) + output = _block_double_dequantize( + ctx, + target, + source_ir, + name, + q_output, + q_scale, + global_scale, + ) + elif ".input_quantizer" in name: + # dynamic double quantization is used for inputs + # Add DYQ node + q_output, q_scale = _dynamic_quantize( + ctx, + target, + source_ir, + name, + input_tensor, + global_scale, + ) + # Add double DQ node + output = _block_double_dequantize( + ctx, + target, + source_ir, + name, + q_output, + q_scale, + global_scale, + ) + else: + raise ValueError( + f"dynamic_block_quantize converter received an input of {name} name. Supported names: weight_quantizer | input_quantizer" + ) + return output + + +def _dynamic_quantize( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_tensor: TRTTensor, + global_scale: TRTTensor, + axis: int = -1, + block_size: int = 16, + output_type: trt.DataType = trt.DataType.FP4, + scale_type: trt.DataType = trt.DataType.FP8, +) -> Tuple[TRTTensor, TRTTensor]: + """ + Parameters: + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR] + name: str + input_tensor : Tensor (On GPU) + The input tensor. + global_scale : Tensor (On GPU) + The global per-tensor scaling factor. It should contain only 1 element. + axis : int + The axis to quantize. Default is -1 (the last axis). + block_size : int + The block size for quantization. Default is 16. + data_qtype : trt.DataType + The data type for quantized data. Default is FP4. + scale_qtype : trt.DataType + The data type for block scale. Default is FP8. + Returns: + A tuple of two tensors: quantized tensor in f4 and block scale tensor. + """ + if len(input_tensor.shape) not in (2, 3): + raise ValueError( + f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. 
Supported shapes: 2D or 3D" ) - scale_dequantize_layer.precision = trt.DataType.FP8 - scale_dq_output = scale_dequantize_layer.get_output(0) + if axis < 0: + axis = len(input_tensor.shape) + axis + # Add DYQ node + dynamic_quantize_layer = ctx.net.add_dynamic_quantize( + input_tensor, + axis, + block_size, + output_type, + scale_type, + ) + dynamic_quantize_layer.set_input(1, global_scale) + set_layer_name( + dynamic_quantize_layer, target, name + "_dynamic_quantize", source_ir + ) + q_output = dynamic_quantize_layer.get_output(0) + q_scale = dynamic_quantize_layer.get_output(1) + return q_output, q_scale - dequantize_layer = ctx.net.add_dequantize(q_output, scale_dq_output) - dequantize_layer.axis = 1 - set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir) - dequantize_layer.precision = trt.DataType.FP4 - dq_output = dequantize_layer.get_output(0) - return dq_output +def _block_double_dequantize( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_tensor: TRTTensor, + scale: TRTTensor, + global_scale: TRTTensor, + dtype: trt.DataType = trt.DataType.FLOAT, +) -> TRTTensor: + """ + Parameters: + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR] + name: str + input_tensor : Tensor (On GPU) + The input tensor. + scale : Tensor (On GPU) + The block scale tensor. + global_scale : Tensor (On GPU) + The global per-tensor scaling factor. It should contain only 1 element. + dtype : trt.DataType | str + The data type for dequantized data. Default is float32. + Returns: + The dequantized tensor. + """ + # dequantize scale from fp8 to dtype(default is float32) + dequantize_scale_layer = ctx.net.add_dequantize(scale, global_scale, dtype) + set_layer_name( + dequantize_scale_layer, target, name + "_dequantize_scale", source_ir + ) + dequantized_scale = dequantize_scale_layer.get_output(0) + + # dequantize input_tensor from fp4 to dtype(default is float32) + dequantize_data_layer = ctx.net.add_dequantize( + input_tensor, dequantized_scale, dtype + ) + set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) + dq_output = dequantize_data_layer.get_output(0) + return dq_output + + +def _static_double_quantize( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_tensor: TRTTensor, + global_scale: TRTTensor, +) -> Tuple[TRTTensor, TRTTensor]: + """ + Parameters: + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_tensor : Tensor (On GPU) + The input tensor. + global_scale : Tensor (On GPU) + The global per-tensor scaling factor. It should contain only 1 element. 
+ Returns: + A tuple of two tensors: quantized tensor and scaling factor tensor + """ + pass + return input_tensor, global_scale + # quantize_layer = ctx.net.add_quantize(input_tensor, global_scale) + # set_layer_name(quantize_layer, target, name + "_quantize", source_ir) + # q_output = quantize_layer.get_output(0) + # q_scale = quantize_layer.get_output(1) + + # return q_output, q_scale diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index c7985c5cdb..81f902440b 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -230,7 +230,7 @@ def calibrate_loop(model): quant_cfg = mtq.NVFP4_DEFAULT_CFG mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - # model has FP4 qdq nodes at this point + # model has qdq nodes at this point output_pyt = model(input_tensor) with torch.no_grad(): From d9f2ad9a2fb261a36b9c9e11e6f93c7088f31715 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Tue, 6 May 2025 16:17:35 -0700 Subject: [PATCH 08/30] test --- .../dynamo/conversion/aten_ops_converters.py | 2 +- .../dynamo/conversion/impl/quantize.py | 151 ++++++++++++------ tests/py/dynamo/models/test_models_export.py | 5 +- 3 files changed, 104 insertions(+), 54 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index af5a783fcd..fb02c975d1 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -635,7 +635,7 @@ def aten_ops_dynamic_block_quantize_op( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.quantize.dynamic_block_quantize( + return impl.quantize.nvfp4_quantize( ctx, target, SourceIR.ATEN, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py index 7939872e0e..9953ffed01 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py @@ -71,7 +71,7 @@ def quantize( return dq_output -def dynamic_block_quantize( +def nvfp4_quantize( ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], @@ -88,50 +88,71 @@ def dynamic_block_quantize( Adds quantize and dequantize ops (QDQ) which quantize to FP4 based on the output_type set and dequantizes them back. """ - print( - f"dynamic_block_quantize entered: {target=} {source_ir=} {name=} {input_tensor.shape=} {input_tensor.dtype=} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}" - ) + with unset_fake_temporarily(): - if not isinstance(input_tensor, TRTTensor): - input_tensor = get_trt_tensor(ctx, input_tensor, name + "_input") - if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in ( - trt.float32, - trt.float16, - trt.bfloat16, - ): - raise ValueError( - f"dynamic_block_quantize converter received an input of {input_tensor.dtype} type. 
Supported types: float32 | float16 | bfloat16" - ) + if ".weight_quantizer" in name: + # calculate global scale (the global per-tensor scaling factor, should only contain 1 element) + amax = to_torch( + amax, None + ) # amax is calculated from input_tensor.abs().amax().float() + global_scale = torch.divide(amax, 6) - # calculate global scale (the global per-tensor scaling factor, should only contain 1 element) - max_bound = 6 - amax = to_torch(amax, None) - global_scale = torch.divide(amax, max_bound) - global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") + # calculate block scaling factor of weights + [n, k] = input_tensor.shape[-2:] + assert block_size != 0, "block_size must be non-zero" + assert k % block_size == 0, "k must be a multiple of block_size" + reshaped_input_tensor = input_tensor.reshape( + tuple(input_tensor.shape[:-2]) + (n, k // block_size, block_size) + ) + per_block_amax = reshaped_input_tensor.abs().amax(dim=-1).float() + per_block_scale = torch.divide(per_block_amax, 6) - if ".weight_quantizer" in name: + global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") + per_block_scale = get_trt_tensor( + ctx, per_block_scale, name + "_per_block_scale" + ) + input_tensor = get_trt_tensor(ctx, input_tensor, name + "_input") # static double quantization is used for weights - q_output, q_scale = _static_double_quantize( - ctx, - target, - source_ir, - name, - input_tensor, - global_scale, + quantized_data_in_fp4, quantized_block_scale_in_fp8 = ( + _static_double_quantize( + ctx, + target, + source_ir, + name, + input_tensor, + per_block_scale, + global_scale, + ) ) output = _block_double_dequantize( ctx, target, source_ir, name, - q_output, - q_scale, + quantized_data_in_fp4, + quantized_block_scale_in_fp8, global_scale, ) elif ".input_quantizer" in name: + if not isinstance(input_tensor, TRTTensor): + input_tensor = get_trt_tensor(ctx, input_tensor, name + "_input") + if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in ( + trt.float32, + trt.float16, + trt.bfloat16, + ): + raise ValueError( + f"dynamic_block_quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16 | bfloat16" + ) + # dynamic double quantization is used for inputs - # Add DYQ node - q_output, q_scale = _dynamic_quantize( + # calculate global scale (the global per-tensor scaling factor, should only contain 1 element) + amax = to_torch(amax, None) + global_scale = torch.divide(amax, 6) + global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") + + # quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 + quantized_data_in_fp4, quantized_scale_in_fp8 = _dynamic_quantize( ctx, target, source_ir, @@ -145,8 +166,8 @@ def dynamic_block_quantize( target, source_ir, name, - q_output, - q_scale, + quantized_data_in_fp4, + quantized_scale_in_fp8, global_scale, ) else: @@ -169,6 +190,7 @@ def _dynamic_quantize( scale_type: trt.DataType = trt.DataType.FP8, ) -> Tuple[TRTTensor, TRTTensor]: """ + quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 Parameters: ctx: ConversionContext, target: Target, @@ -187,7 +209,7 @@ def _dynamic_quantize( scale_qtype : trt.DataType The data type for block scale. Default is FP8. Returns: - A tuple of two tensors: quantized tensor in f4 and block scale tensor. + A tuple of two tensors: quantized data tensor in fp4 and quantized scale tensor in fp8. 
""" if len(input_tensor.shape) not in (2, 3): raise ValueError( @@ -207,9 +229,9 @@ def _dynamic_quantize( set_layer_name( dynamic_quantize_layer, target, name + "_dynamic_quantize", source_ir ) - q_output = dynamic_quantize_layer.get_output(0) - q_scale = dynamic_quantize_layer.get_output(1) - return q_output, q_scale + quantized_data_in_fp4 = dynamic_quantize_layer.get_output(0) + quantized_scale_in_fp8 = dynamic_quantize_layer.get_output(1) + return quantized_data_in_fp4, quantized_scale_in_fp8 def _block_double_dequantize( @@ -223,9 +245,10 @@ def _block_double_dequantize( dtype: trt.DataType = trt.DataType.FLOAT, ) -> TRTTensor: """ - Parameters: - ctx: ConversionContext, - target: Target, + dequantize input_tensor from fp4 to dtype(default is float32) + Parameters: + ctx: ConversionContext, + target: Target, source_ir: Optional[SourceIR] name: str input_tensor : Tensor (On GPU) @@ -261,26 +284,50 @@ def _static_double_quantize( source_ir: Optional[SourceIR], name: str, input_tensor: TRTTensor, + per_block_scale: TRTTensor, global_scale: TRTTensor, ) -> Tuple[TRTTensor, TRTTensor]: """ Parameters: - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, input_tensor : Tensor (On GPU) The input tensor. + per_block_scale : Tensor (On GPU) + The per-block scaling factor. global_scale : Tensor (On GPU) The global per-tensor scaling factor. It should contain only 1 element. Returns: - A tuple of two tensors: quantized tensor and scaling factor tensor + A tuple of two tensors: quantized data tensor in fp4 and quantized block scaling factor tensor in fp8 """ - pass - return input_tensor, global_scale - # quantize_layer = ctx.net.add_quantize(input_tensor, global_scale) - # set_layer_name(quantize_layer, target, name + "_quantize", source_ir) - # q_output = quantize_layer.get_output(0) - # q_scale = quantize_layer.get_output(1) - # return q_output, q_scale + block_scale_quantize_layer = ctx.net.add_quantize(per_block_scale, global_scale) + set_layer_name( + block_scale_quantize_layer, + target, + name + "_per_block_scale_quantize", + source_ir, + ) + block_scale_quantize_layer.set_output_type(0, trt.DataType.FP8) + quantized_block_scale_in_fp8 = block_scale_quantize_layer.get_output(0) + + dequantize_block_scale_layer = ctx.net.add_dequantize( + quantized_block_scale_in_fp8, global_scale + ) + set_layer_name( + dequantize_block_scale_layer, + target, + name + "_dequantize_block_scale", + source_ir, + ) + dequantize_block_scale_layer.precision = trt.DataType.FP8 + dequantized_block_scale = dequantize_block_scale_layer.get_output(0) + + data_quantize_layer = ctx.net.add_quantize(input_tensor, dequantized_block_scale) + set_layer_name(data_quantize_layer, target, name + "_data_quantize", source_ir) + data_quantize_layer.set_output_type(0, trt.DataType.FP4) + quantized_data_in_fp4 = data_quantize_layer.get_output(0) + + return quantized_data_in_fp4, quantized_block_scale_in_fp8 diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 81f902440b..e9c128ee40 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -225,7 +225,8 @@ def calibrate_loop(model): """Simple calibration function for testing.""" model(input_tensor) - input_tensor = torch.randn(1, 16).cuda() + input_tensor = torch.randn(5, 16).cuda() + print(f"amax: {input_tensor.abs().amax()}") model = 
SimpleNetwork().eval().cuda() quant_cfg = mtq.NVFP4_DEFAULT_CFG @@ -246,6 +247,8 @@ def calibrate_loop(model): reuse_cached_engines=False, ) outputs_trt = trt_model(input_tensor) + print(f"lan added outputs_trt: {outputs_trt}") + print(f"lan added output_pyt: {output_pyt}") assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=5e-1) From 6892a4749146e127ac79db935f7589575af9569a Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Tue, 6 May 2025 16:41:58 -0700 Subject: [PATCH 09/30] test --- .../dynamo/conversion/impl/quantize.py | 54 +++++++++---------- tests/py/dynamo/models/test_models_export.py | 12 ++--- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py index 9953ffed01..cce91f2725 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py @@ -88,15 +88,30 @@ def nvfp4_quantize( Adds quantize and dequantize ops (QDQ) which quantize to FP4 based on the output_type set and dequantizes them back. """ - + print( + f"lan added nvfp4_quantize entered: {target=} {source_ir=} {name=} {input_tensor.shape=} {input_tensor.dtype=} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}" + ) with unset_fake_temporarily(): - if ".weight_quantizer" in name: - # calculate global scale (the global per-tensor scaling factor, should only contain 1 element) - amax = to_torch( - amax, None - ) # amax is calculated from input_tensor.abs().amax().float() - global_scale = torch.divide(amax, 6) + if not isinstance(input_tensor, TRTTensor): + input_tensor = get_trt_tensor(ctx, input_tensor, name + "_input") + if input_tensor.dtype not in ( + trt.float32, + trt.float16, + trt.bfloat16, + ): + raise ValueError( + f"dynamic_block_quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16 | bfloat16" + ) + # TODO: ADD PADDING IF + + # calculate global scale (the global per-tensor scaling factor, should only contain 1 element) + amax = to_torch( + amax, None + ) # amax is calculated from input_tensor.abs().amax().float() + global_scale = torch.divide(amax, 6) + global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") + if ".weight_quantizer" in name: # calculate block scaling factor of weights [n, k] = input_tensor.shape[-2:] assert block_size != 0, "block_size must be non-zero" @@ -107,11 +122,10 @@ def nvfp4_quantize( per_block_amax = reshaped_input_tensor.abs().amax(dim=-1).float() per_block_scale = torch.divide(per_block_amax, 6) - global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") per_block_scale = get_trt_tensor( ctx, per_block_scale, name + "_per_block_scale" ) - input_tensor = get_trt_tensor(ctx, input_tensor, name + "_input") + # static double quantization is used for weights quantized_data_in_fp4, quantized_block_scale_in_fp8 = ( _static_double_quantize( @@ -134,23 +148,6 @@ def nvfp4_quantize( global_scale, ) elif ".input_quantizer" in name: - if not isinstance(input_tensor, TRTTensor): - input_tensor = get_trt_tensor(ctx, input_tensor, name + "_input") - if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in ( - trt.float32, - trt.float16, - trt.bfloat16, - ): - raise ValueError( - f"dynamic_block_quantize converter received an input of {input_tensor.dtype} type. 
Supported types: float32 | float16 | bfloat16" - ) - - # dynamic double quantization is used for inputs - # calculate global scale (the global per-tensor scaling factor, should only contain 1 element) - amax = to_torch(amax, None) - global_scale = torch.divide(amax, 6) - global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") - # quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 quantized_data_in_fp4, quantized_scale_in_fp8 = _dynamic_quantize( ctx, @@ -169,6 +166,7 @@ def nvfp4_quantize( quantized_data_in_fp4, quantized_scale_in_fp8, global_scale, + input_tensor.dtype, ) else: raise ValueError( @@ -314,7 +312,9 @@ def _static_double_quantize( quantized_block_scale_in_fp8 = block_scale_quantize_layer.get_output(0) dequantize_block_scale_layer = ctx.net.add_dequantize( - quantized_block_scale_in_fp8, global_scale + quantized_block_scale_in_fp8, + global_scale, + per_block_scale.dtype, ) set_layer_name( dequantize_block_scale_layer, diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index e9c128ee40..dd52af215c 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -199,10 +199,10 @@ def test_resnet18_half(ir): torch._dynamo.reset() -# @unittest.skipIf( -# torch.cuda.get_device_capability() < (10, 0), -# "FP4 quantization requires compute capability 10.0 or later", -# ) +@unittest.skipIf( + torch.cuda.get_device_capability() < (10, 0), + "FP4 quantization requires compute capability 10.0 or later", +) @unittest.skipIf( not importlib.util.find_spec("modelopt"), "ModelOpt is required to run this test", @@ -226,7 +226,7 @@ def calibrate_loop(model): model(input_tensor) input_tensor = torch.randn(5, 16).cuda() - print(f"amax: {input_tensor.abs().amax()}") + print(f"lan added amax: {input_tensor.abs().amax()}") model = SimpleNetwork().eval().cuda() quant_cfg = mtq.NVFP4_DEFAULT_CFG @@ -249,7 +249,7 @@ def calibrate_loop(model): outputs_trt = trt_model(input_tensor) print(f"lan added outputs_trt: {outputs_trt}") print(f"lan added output_pyt: {output_pyt}") - assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=5e-1) + assert torch.allclose(output_pyt, outputs_trt, rtol=5e-1, atol=5e-1) @unittest.skipIf( From 559ada57331d09e243ceecbeba7720044b3301ca Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 7 May 2025 15:38:17 -0700 Subject: [PATCH 10/30] restructure the dynamic double quantize and static double quantize code --- .../dynamo/conversion/impl/quantize.py | 200 ++++++++---------- tests/py/dynamo/models/test_models_export.py | 5 +- 2 files changed, 94 insertions(+), 111 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py index cce91f2725..23464989ca 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Optional, Union import numpy as np import tensorrt as trt @@ -92,90 +92,64 @@ def nvfp4_quantize( f"lan added nvfp4_quantize entered: {target=} {source_ir=} {name=} {input_tensor.shape=} {input_tensor.dtype=} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}" ) with unset_fake_temporarily(): - if not isinstance(input_tensor, TRTTensor): - input_tensor = get_trt_tensor(ctx, input_tensor, name + "_input") if input_tensor.dtype not in ( 
trt.float32, trt.float16, trt.bfloat16, + torch.float32, + torch.float16, + torch.bfloat16, ): raise ValueError( f"dynamic_block_quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16 | bfloat16" ) - # TODO: ADD PADDING IF - - # calculate global scale (the global per-tensor scaling factor, should only contain 1 element) - amax = to_torch( - amax, None - ) # amax is calculated from input_tensor.abs().amax().float() - global_scale = torch.divide(amax, 6) - global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") - - if ".weight_quantizer" in name: - # calculate block scaling factor of weights - [n, k] = input_tensor.shape[-2:] - assert block_size != 0, "block_size must be non-zero" - assert k % block_size == 0, "k must be a multiple of block_size" - reshaped_input_tensor = input_tensor.reshape( - tuple(input_tensor.shape[:-2]) + (n, k // block_size, block_size) + if len(input_tensor.shape) not in (2, 3): + raise ValueError( + f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D" ) - per_block_amax = reshaped_input_tensor.abs().amax(dim=-1).float() - per_block_scale = torch.divide(per_block_amax, 6) + axis = len(input_tensor.shape) - 1 - per_block_scale = get_trt_tensor( - ctx, per_block_scale, name + "_per_block_scale" - ) + # TODO: ADD PADDING IF NEEDED + # TODO: ADD DYNAMIC SHAPE SUPPORT - # static double quantization is used for weights - quantized_data_in_fp4, quantized_block_scale_in_fp8 = ( - _static_double_quantize( - ctx, - target, - source_ir, - name, - input_tensor, - per_block_scale, - global_scale, - ) - ) - output = _block_double_dequantize( + global_scale = _calculate_global_scale(ctx, name, amax) + + if ".weight_quantizer" in name: + block_scale = _calculate_block_scale( ctx, - target, - source_ir, name, - quantized_data_in_fp4, - quantized_block_scale_in_fp8, - global_scale, + input_tensor, + block_size, ) - elif ".input_quantizer" in name: - # quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 - quantized_data_in_fp4, quantized_scale_in_fp8 = _dynamic_quantize( + input_tensor = get_trt_tensor(ctx, input_tensor, name + "_input") + output = _static_double_quantize( ctx, target, source_ir, name, input_tensor, + block_scale, global_scale, ) - # Add double DQ node - output = _block_double_dequantize( + elif ".input_quantizer" in name: + # quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 + output = _dynamic_double_quantize( ctx, target, source_ir, name, - quantized_data_in_fp4, - quantized_scale_in_fp8, + input_tensor, global_scale, - input_tensor.dtype, ) + else: raise ValueError( - f"dynamic_block_quantize converter received an input of {name} name. Supported names: weight_quantizer | input_quantizer" + f"quantizer received an input of {name}. Supported values: weight_quantizer | input_quantizer" ) return output -def _dynamic_quantize( +def _dynamic_double_quantize( ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], @@ -186,7 +160,7 @@ def _dynamic_quantize( block_size: int = 16, output_type: trt.DataType = trt.DataType.FP4, scale_type: trt.DataType = trt.DataType.FP8, -) -> Tuple[TRTTensor, TRTTensor]: +) -> TRTTensor: """ quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 Parameters: @@ -202,20 +176,13 @@ def _dynamic_quantize( The axis to quantize. Default is -1 (the last axis). 
block_size : int The block size for quantization. Default is 16. - data_qtype : trt.DataType + output_type : trt.DataType The data type for quantized data. Default is FP4. - scale_qtype : trt.DataType + scale_type : trt.DataType The data type for block scale. Default is FP8. - Returns: - A tuple of two tensors: quantized data tensor in fp4 and quantized scale tensor in fp8. + """ - if len(input_tensor.shape) not in (2, 3): - raise ValueError( - f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D" - ) - if axis < 0: - axis = len(input_tensor.shape) + axis - # Add DYQ node + # dynamic quantize input tensor to fp4 dynamic_quantize_layer = ctx.net.add_dynamic_quantize( input_tensor, axis, @@ -229,51 +196,23 @@ def _dynamic_quantize( ) quantized_data_in_fp4 = dynamic_quantize_layer.get_output(0) quantized_scale_in_fp8 = dynamic_quantize_layer.get_output(1) - return quantized_data_in_fp4, quantized_scale_in_fp8 - -def _block_double_dequantize( - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input_tensor: TRTTensor, - scale: TRTTensor, - global_scale: TRTTensor, - dtype: trt.DataType = trt.DataType.FLOAT, -) -> TRTTensor: - """ - dequantize input_tensor from fp4 to dtype(default is float32) - Parameters: - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR] - name: str - input_tensor : Tensor (On GPU) - The input tensor. - scale : Tensor (On GPU) - The block scale tensor. - global_scale : Tensor (On GPU) - The global per-tensor scaling factor. It should contain only 1 element. - dtype : trt.DataType | str - The data type for dequantized data. Default is float32. - Returns: - The dequantized tensor. - """ - # dequantize scale from fp8 to dtype(default is float32) - dequantize_scale_layer = ctx.net.add_dequantize(scale, global_scale, dtype) + # dequantize scale from fp8 to orignal dtype(default is float32) + dequantize_scale_layer = ctx.net.add_dequantize( + quantized_scale_in_fp8, global_scale, input_tensor.dtype + ) set_layer_name( dequantize_scale_layer, target, name + "_dequantize_scale", source_ir ) dequantized_scale = dequantize_scale_layer.get_output(0) - # dequantize input_tensor from fp4 to dtype(default is float32) + # dequantize quantized_data_in_fp4 from fp4 to orignal dtype(default is float32) dequantize_data_layer = ctx.net.add_dequantize( - input_tensor, dequantized_scale, dtype + quantized_data_in_fp4, dequantized_scale, input_tensor.dtype ) set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) - dq_output = dequantize_data_layer.get_output(0) - return dq_output + dequantized_data = dequantize_data_layer.get_output(0) + return dequantized_data def _static_double_quantize( @@ -282,9 +221,9 @@ def _static_double_quantize( source_ir: Optional[SourceIR], name: str, input_tensor: TRTTensor, - per_block_scale: TRTTensor, + block_scale: TRTTensor, global_scale: TRTTensor, -) -> Tuple[TRTTensor, TRTTensor]: +) -> TRTTensor: """ Parameters: ctx: ConversionContext, @@ -293,28 +232,29 @@ def _static_double_quantize( name: str, input_tensor : Tensor (On GPU) The input tensor. - per_block_scale : Tensor (On GPU) + block_scale : Tensor (On GPU) The per-block scaling factor. global_scale : Tensor (On GPU) The global per-tensor scaling factor. It should contain only 1 element. 
Returns: A tuple of two tensors: quantized data tensor in fp4 and quantized block scaling factor tensor in fp8 """ - - block_scale_quantize_layer = ctx.net.add_quantize(per_block_scale, global_scale) + # quantize block scale to fp8 + block_scale_quantize_layer = ctx.net.add_quantize(block_scale, global_scale) set_layer_name( block_scale_quantize_layer, target, - name + "_per_block_scale_quantize", + name + "_block_scale_quantize", source_ir, ) block_scale_quantize_layer.set_output_type(0, trt.DataType.FP8) quantized_block_scale_in_fp8 = block_scale_quantize_layer.get_output(0) + # dequantize block scale from fp8 to original dtype(default is float32) dequantize_block_scale_layer = ctx.net.add_dequantize( quantized_block_scale_in_fp8, global_scale, - per_block_scale.dtype, + block_scale.dtype, ) set_layer_name( dequantize_block_scale_layer, @@ -322,12 +262,54 @@ def _static_double_quantize( name + "_dequantize_block_scale", source_ir, ) - dequantize_block_scale_layer.precision = trt.DataType.FP8 dequantized_block_scale = dequantize_block_scale_layer.get_output(0) + # quantize input tensor to fp4 data_quantize_layer = ctx.net.add_quantize(input_tensor, dequantized_block_scale) set_layer_name(data_quantize_layer, target, name + "_data_quantize", source_ir) data_quantize_layer.set_output_type(0, trt.DataType.FP4) quantized_data_in_fp4 = data_quantize_layer.get_output(0) - return quantized_data_in_fp4, quantized_block_scale_in_fp8 + # dequantize input tensor from fp4 to originaldtype(default is float32) + dequantize_data_layer = ctx.net.add_dequantize( + quantized_data_in_fp4, + dequantized_block_scale, + input_tensor.dtype, + ) + set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) + dequantized_data = dequantize_data_layer.get_output(0) + return dequantized_data + + +def _calculate_global_scale( + ctx: ConversionContext, + name: str, + amax: TRTTensor, +) -> TRTTensor: + # calculate global scale (the global per-tensor scaling factor, should only contain 1 element) + amax = to_torch( + amax, None + ) # amax is calculated from input_tensor.abs().amax().float() + global_scale = torch.divide(amax, 6 * 448) + global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") + return global_scale + + +def _calculate_block_scale( + ctx: ConversionContext, + name: str, + input_tensor: TRTTensor, + block_size: int, +) -> TRTTensor: + + [n, k] = input_tensor.shape[-2:] + assert block_size != 0, "block_size must be non-zero" + assert k % block_size == 0, "k must be a multiple of block_size" + reshaped_input_tensor = input_tensor.reshape( + tuple(input_tensor.shape[:-2]) + (n, k // block_size, block_size) + ) + block_amax = reshaped_input_tensor.abs().amax(dim=-1).float() + block_scale = torch.divide(block_amax, 6) + + block_scale = get_trt_tensor(ctx, block_scale, name + "_block_scale") + return block_scale diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index dd52af215c..8b3646bb24 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -215,7 +215,7 @@ def test_base_fp4(ir): class SimpleNetwork(torch.nn.Module): def __init__(self): super(SimpleNetwork, self).__init__() - self.linear1 = torch.nn.Linear(in_features=16, out_features=5) + self.linear1 = torch.nn.Linear(in_features=16, out_features=3) def forward(self, x): x = self.linear1(x) @@ -249,7 +249,7 @@ def calibrate_loop(model): outputs_trt = trt_model(input_tensor) print(f"lan added outputs_trt: 
{outputs_trt}") print(f"lan added output_pyt: {output_pyt}") - assert torch.allclose(output_pyt, outputs_trt, rtol=5e-1, atol=5e-1) + assert torch.allclose(output_pyt, outputs_trt, rtol=4e-1, atol=4e-1) @unittest.skipIf( @@ -284,6 +284,7 @@ def calibrate_loop(model): input_tensor = torch.randn(1, 10).cuda() model = SimpleNetwork().eval().cuda() quant_cfg = mtq.FP8_DEFAULT_CFG + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) # model has FP8 qdq nodes at this point output_pyt = model(input_tensor) From bba1d795f0c0babae737366eaede759a644eeb60 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 14 May 2025 13:46:04 -0700 Subject: [PATCH 11/30] add test code --- py/torch_tensorrt/_enums.py | 2 + .../dynamo/conversion/aten_ops_converters.py | 2 +- .../dynamo/conversion/impl/__init__.py | 1 + .../dynamo/conversion/impl/quantize.py | 451 +++++++++--------- tests/py/dynamo/models/test_models_export.py | 8 +- 5 files changed, 236 insertions(+), 228 deletions(-) diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py index 6d3250c681..e0a78e1a0b 100644 --- a/py/torch_tensorrt/_enums.py +++ b/py/torch_tensorrt/_enums.py @@ -409,6 +409,8 @@ def to( return trt.DataType.BOOL elif self == dtype.bf16: return trt.DataType.BF16 + elif self == dtype.f4: + return trt.DataType.FP4 elif use_default: return trt.DataType.FLOAT else: diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index fb02c975d1..08a1bb4ea4 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -635,7 +635,7 @@ def aten_ops_dynamic_block_quantize_op( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.quantize.nvfp4_quantize( + return impl.nvfp4_quantize.nvfp4_quantize( ctx, target, SourceIR.ATEN, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index df580b1516..1f2d9d0de1 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -14,6 +14,7 @@ matmul, nccl_ops, normalization, + nvfp4_quantize, pad, permutation, pool, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py index 23464989ca..192f9c648a 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py @@ -71,245 +71,250 @@ def quantize( return dq_output -def nvfp4_quantize( - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input_tensor: TRTTensor, - block_size: int, - amax: Union[np.ndarray, torch.Tensor], - num_bits: int, - exponent_bits: int, - scale_num_bits: int, - scale_exponent_bits: int, -) -> TRTTensor: - """ - Adds quantize and dequantize ops (QDQ) which quantize to FP4 based - on the output_type set and dequantizes them back. - """ - print( - f"lan added nvfp4_quantize entered: {target=} {source_ir=} {name=} {input_tensor.shape=} {input_tensor.dtype=} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}" - ) - with unset_fake_temporarily(): - if input_tensor.dtype not in ( - trt.float32, - trt.float16, - trt.bfloat16, - torch.float32, - torch.float16, - torch.bfloat16, - ): - raise ValueError( - f"dynamic_block_quantize converter received an input of {input_tensor.dtype} type. 
Supported types: float32 | float16 | bfloat16" - ) - if len(input_tensor.shape) not in (2, 3): - raise ValueError( - f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D" - ) - axis = len(input_tensor.shape) - 1 +# def nvfp4_quantize( +# ctx: ConversionContext, +# target: Target, +# source_ir: Optional[SourceIR], +# name: str, +# input_tensor: TRTTensor, +# block_size: int, +# amax: Union[np.ndarray, torch.Tensor], +# num_bits: int, +# exponent_bits: int, +# scale_num_bits: int, +# scale_exponent_bits: int, +# ) -> TRTTensor: +# """ +# Adds quantize and dequantize ops (QDQ) which quantize to FP4 based +# on the output_type set and dequantizes them back. +# """ +# print( +# f"lan added nvfp4_quantize entered: {target=} {source_ir=} {name=} {input_tensor.shape=} {input_tensor.dtype=} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}" +# ) +# with unset_fake_temporarily(): +# if input_tensor.dtype not in ( +# trt.float32, +# trt.float16, +# trt.bfloat16, +# torch.float32, +# torch.float16, +# torch.bfloat16, +# ): +# raise ValueError( +# f"dynamic_block_quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16 | bfloat16" +# ) +# if len(input_tensor.shape) not in (2, 3): +# raise ValueError( +# f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D" +# ) +# axis = len(input_tensor.shape) - 1 - # TODO: ADD PADDING IF NEEDED - # TODO: ADD DYNAMIC SHAPE SUPPORT +# # TODO: ADD PADDING IF NEEDED +# # TODO: ADD DYNAMIC SHAPE SUPPORT - global_scale = _calculate_global_scale(ctx, name, amax) +# global_scale = _calculate_global_scale(ctx, name, amax) - if ".weight_quantizer" in name: - block_scale = _calculate_block_scale( - ctx, - name, - input_tensor, - block_size, - ) - input_tensor = get_trt_tensor(ctx, input_tensor, name + "_input") - output = _static_double_quantize( - ctx, - target, - source_ir, - name, - input_tensor, - block_scale, - global_scale, - ) - elif ".input_quantizer" in name: - # quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 - output = _dynamic_double_quantize( - ctx, - target, - source_ir, - name, - input_tensor, - global_scale, - ) +# if ".weight_quantizer" in name: +# block_scale = _calculate_block_scale( +# ctx, +# name, +# input_tensor, +# block_size, +# ) +# input_tensor = get_trt_tensor(ctx, input_tensor, name + "_input") +# output = _static_double_quantize( +# ctx, +# target, +# source_ir, +# name, +# input_tensor, +# block_scale, +# global_scale, +# ) +# elif ".input_quantizer" in name: +# # quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 +# output = _dynamic_double_quantize( +# ctx, +# target, +# source_ir, +# name, +# input_tensor, +# global_scale, +# ) - else: - raise ValueError( - f"quantizer received an input of {name}. Supported values: weight_quantizer | input_quantizer" - ) - return output +# else: +# raise ValueError( +# f"quantizer received an input of {name}. 
Supported values: weight_quantizer | input_quantizer" +# ) +# return output -def _dynamic_double_quantize( - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input_tensor: TRTTensor, - global_scale: TRTTensor, - axis: int = -1, - block_size: int = 16, - output_type: trt.DataType = trt.DataType.FP4, - scale_type: trt.DataType = trt.DataType.FP8, -) -> TRTTensor: - """ - quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 - Parameters: - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR] - name: str - input_tensor : Tensor (On GPU) - The input tensor. - global_scale : Tensor (On GPU) - The global per-tensor scaling factor. It should contain only 1 element. - axis : int - The axis to quantize. Default is -1 (the last axis). - block_size : int - The block size for quantization. Default is 16. - output_type : trt.DataType - The data type for quantized data. Default is FP4. - scale_type : trt.DataType - The data type for block scale. Default is FP8. - - """ - # dynamic quantize input tensor to fp4 - dynamic_quantize_layer = ctx.net.add_dynamic_quantize( - input_tensor, - axis, - block_size, - output_type, - scale_type, - ) - dynamic_quantize_layer.set_input(1, global_scale) - set_layer_name( - dynamic_quantize_layer, target, name + "_dynamic_quantize", source_ir - ) - quantized_data_in_fp4 = dynamic_quantize_layer.get_output(0) - quantized_scale_in_fp8 = dynamic_quantize_layer.get_output(1) +# def _dynamic_double_quantize( +# ctx: ConversionContext, +# target: Target, +# source_ir: Optional[SourceIR], +# name: str, +# input_tensor: TRTTensor, +# global_scale: TRTTensor, +# axis: int = -1, +# block_size: int = 16, +# output_type: trt.DataType = trt.DataType.FP4, +# scale_type: trt.DataType = trt.DataType.FP8, +# ) -> TRTTensor: +# """ +# quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 +# Parameters: +# ctx: ConversionContext, +# target: Target, +# source_ir: Optional[SourceIR] +# name: str +# input_tensor : Tensor (On GPU) +# The input tensor. +# global_scale : Tensor (On GPU) +# The global per-tensor scaling factor. It should contain only 1 element. +# axis : int +# The axis to quantize. Default is -1 (the last axis). +# block_size : int +# The block size for quantization. Default is 16. +# output_type : trt.DataType +# The data type for quantized data. Default is FP4. +# scale_type : trt.DataType +# The data type for block scale. Default is FP8. 
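The two scaling factors used throughout these converters follow from the representable ranges of the two formats: FP4 E2M1 tops out at a magnitude of 6 and FP8 E4M3 at 448. A rough PyTorch sketch of the double-scaling arithmetic mirrored by the amax / (6 * 448) and block_amax / 6 expressions in this patch (the all-zero guard follows the same convention as the converter; this is an illustration, not the converter code):

import torch

FP4_E2M1_MAX = 6.0    # largest magnitude representable in FP4 E2M1
FP8_E4M3_MAX = 448.0  # largest magnitude representable in FP8 E4M3

def nvfp4_scales(weights: torch.Tensor, block_size: int = 16):
    # Global per-tensor scale: sized so the largest per-block scale still fits in FP8.
    amax = weights.abs().amax().float()
    global_scale = amax / (FP4_E2M1_MAX * FP8_E4M3_MAX)
    if global_scale == 0:
        global_scale = torch.ones_like(global_scale)

    # Per-block scale: maps each block's amax onto the FP4 maximum, expressed
    # relative to the global scale so it can itself be stored as FP8.
    n, k = weights.shape[-2:]
    assert block_size and k % block_size == 0, "k must be a multiple of block_size"
    blocks = weights.reshape(*weights.shape[:-2], n, k // block_size, block_size)
    block_scale = blocks.abs().amax(dim=-1).float() / FP4_E2M1_MAX / global_scale
    return global_scale, block_scale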
- # dequantize scale from fp8 to orignal dtype(default is float32) - dequantize_scale_layer = ctx.net.add_dequantize( - quantized_scale_in_fp8, global_scale, input_tensor.dtype - ) - set_layer_name( - dequantize_scale_layer, target, name + "_dequantize_scale", source_ir - ) - dequantized_scale = dequantize_scale_layer.get_output(0) +# """ +# # dynamic quantize input tensor to fp4 +# dynamic_quantize_layer = ctx.net.add_dynamic_quantize( +# input_tensor, +# axis, +# block_size, +# output_type, +# scale_type, +# ) +# dynamic_quantize_layer.set_input(1, global_scale) +# set_layer_name( +# dynamic_quantize_layer, target, name + "_dynamic_quantize", source_ir +# ) +# quantized_data_in_fp4 = dynamic_quantize_layer.get_output(0) +# quantized_scale_in_fp8 = dynamic_quantize_layer.get_output(1) - # dequantize quantized_data_in_fp4 from fp4 to orignal dtype(default is float32) - dequantize_data_layer = ctx.net.add_dequantize( - quantized_data_in_fp4, dequantized_scale, input_tensor.dtype - ) - set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) - dequantized_data = dequantize_data_layer.get_output(0) - return dequantized_data +# # dequantize scale from fp8 to orignal dtype(default is float32) +# dequantize_scale_layer = ctx.net.add_dequantize( +# quantized_scale_in_fp8, global_scale, input_tensor.dtype +# ) +# set_layer_name( +# dequantize_scale_layer, target, name + "_dequantize_scale", source_ir +# ) +# dequantized_scale = dequantize_scale_layer.get_output(0) +# # dequantize quantized_data_in_fp4 from fp4 to orignal dtype(default is float32) +# dequantize_data_layer = ctx.net.add_dequantize( +# quantized_data_in_fp4, dequantized_scale, input_tensor.dtype +# ) +# dequantize_data_layer.axis = axis +# set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) +# dequantized_data = dequantize_data_layer.get_output(0) +# return dequantized_data -def _static_double_quantize( - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input_tensor: TRTTensor, - block_scale: TRTTensor, - global_scale: TRTTensor, -) -> TRTTensor: - """ - Parameters: - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input_tensor : Tensor (On GPU) - The input tensor. - block_scale : Tensor (On GPU) - The per-block scaling factor. - global_scale : Tensor (On GPU) - The global per-tensor scaling factor. It should contain only 1 element. 
- Returns: - A tuple of two tensors: quantized data tensor in fp4 and quantized block scaling factor tensor in fp8 - """ - # quantize block scale to fp8 - block_scale_quantize_layer = ctx.net.add_quantize(block_scale, global_scale) - set_layer_name( - block_scale_quantize_layer, - target, - name + "_block_scale_quantize", - source_ir, - ) - block_scale_quantize_layer.set_output_type(0, trt.DataType.FP8) - quantized_block_scale_in_fp8 = block_scale_quantize_layer.get_output(0) - # dequantize block scale from fp8 to original dtype(default is float32) - dequantize_block_scale_layer = ctx.net.add_dequantize( - quantized_block_scale_in_fp8, - global_scale, - block_scale.dtype, - ) - set_layer_name( - dequantize_block_scale_layer, - target, - name + "_dequantize_block_scale", - source_ir, - ) - dequantized_block_scale = dequantize_block_scale_layer.get_output(0) +# def _static_double_quantize( +# ctx: ConversionContext, +# target: Target, +# source_ir: Optional[SourceIR], +# name: str, +# input_tensor: TRTTensor, +# block_scale: TRTTensor, +# global_scale: TRTTensor, +# ) -> TRTTensor: +# """ +# Parameters: +# ctx: ConversionContext, +# target: Target, +# source_ir: Optional[SourceIR], +# name: str, +# input_tensor : Tensor (On GPU) +# The input tensor. +# block_scale : Tensor (On GPU) +# The per-block scaling factor. +# global_scale : Tensor (On GPU) +# The global per-tensor scaling factor. It should contain only 1 element. +# Returns: +# A tuple of two tensors: quantized data tensor in fp4 and quantized block scaling factor tensor in fp8 +# """ +# # quantize block scale to fp8 +# block_scale_quantize_layer = ctx.net.add_quantize(block_scale, global_scale) +# set_layer_name( +# block_scale_quantize_layer, +# target, +# name + "_block_scale_quantize", +# source_ir, +# ) +# block_scale_quantize_layer.set_output_type(0, trt.DataType.FP8) +# quantized_block_scale_in_fp8 = block_scale_quantize_layer.get_output(0) - # quantize input tensor to fp4 - data_quantize_layer = ctx.net.add_quantize(input_tensor, dequantized_block_scale) - set_layer_name(data_quantize_layer, target, name + "_data_quantize", source_ir) - data_quantize_layer.set_output_type(0, trt.DataType.FP4) - quantized_data_in_fp4 = data_quantize_layer.get_output(0) +# # dequantize block scale from fp8 to original dtype(default is float32) +# dequantize_block_scale_layer = ctx.net.add_dequantize( +# quantized_block_scale_in_fp8, +# global_scale, +# block_scale.dtype, +# ) +# set_layer_name( +# dequantize_block_scale_layer, +# target, +# name + "_dequantize_block_scale", +# source_ir, +# ) +# dequantized_block_scale = dequantize_block_scale_layer.get_output(0) - # dequantize input tensor from fp4 to originaldtype(default is float32) - dequantize_data_layer = ctx.net.add_dequantize( - quantized_data_in_fp4, - dequantized_block_scale, - input_tensor.dtype, - ) - set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) - dequantized_data = dequantize_data_layer.get_output(0) - return dequantized_data +# # quantize input tensor to fp4 +# data_quantize_layer = ctx.net.add_quantize(input_tensor, dequantized_block_scale) +# set_layer_name(data_quantize_layer, target, name + "_data_quantize", source_ir) +# data_quantize_layer.set_output_type(0, trt.DataType.FP4) +# quantized_data_in_fp4 = data_quantize_layer.get_output(0) +# # dequantize input tensor from fp4 to originaldtype(default is float32) +# dequantize_data_layer = ctx.net.add_dequantize( +# quantized_data_in_fp4, +# dequantized_block_scale, +# input_tensor.dtype, 
+# ) +# set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) +# dequantized_data = dequantize_data_layer.get_output(0) +# return dequantized_data -def _calculate_global_scale( - ctx: ConversionContext, - name: str, - amax: TRTTensor, -) -> TRTTensor: - # calculate global scale (the global per-tensor scaling factor, should only contain 1 element) - amax = to_torch( - amax, None - ) # amax is calculated from input_tensor.abs().amax().float() - global_scale = torch.divide(amax, 6 * 448) - global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") - return global_scale +# def _calculate_global_scale( +# ctx: ConversionContext, +# name: str, +# amax: TRTTensor, +# ) -> TRTTensor: +# # calculate global scale (the global per-tensor scaling factor, should only contain 1 element) +# if amax is None or amax == 0: +# amax = 1.0 +# amax = to_torch( +# amax, None +# ) # amax is calculated from input_tensor.abs().amax().float() +# global_scale = torch.divide(amax, 6) # 6*448 +# global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") +# return global_scale -def _calculate_block_scale( - ctx: ConversionContext, - name: str, - input_tensor: TRTTensor, - block_size: int, -) -> TRTTensor: - [n, k] = input_tensor.shape[-2:] - assert block_size != 0, "block_size must be non-zero" - assert k % block_size == 0, "k must be a multiple of block_size" - reshaped_input_tensor = input_tensor.reshape( - tuple(input_tensor.shape[:-2]) + (n, k // block_size, block_size) - ) - block_amax = reshaped_input_tensor.abs().amax(dim=-1).float() - block_scale = torch.divide(block_amax, 6) +# def _calculate_block_scale( +# ctx: ConversionContext, +# name: str, +# input_tensor: TRTTensor, +# block_size: int, +# ) -> TRTTensor: - block_scale = get_trt_tensor(ctx, block_scale, name + "_block_scale") - return block_scale +# [n, k] = input_tensor.shape[-2:] +# assert block_size != 0, "block_size must be non-zero" +# assert k % block_size == 0, "k must be a multiple of block_size" +# reshaped_input_tensor = input_tensor.reshape( +# tuple(input_tensor.shape[:-2]) + (n, k // block_size, block_size) +# ) +# amax = input_tensor.abs().amax().float() +# amax = torch.divide(amax, 6*448) +# block_amax = reshaped_input_tensor.abs().amax(dim=-1).float() +# block_scale = torch.divide(block_amax, 6) +# block_scale = torch.divide(block_scale, amax) +# block_scale = get_trt_tensor(ctx, block_scale, name + "_block_scale") +# return block_scale diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 8b3646bb24..52d746f995 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -199,10 +199,10 @@ def test_resnet18_half(ir): torch._dynamo.reset() -@unittest.skipIf( - torch.cuda.get_device_capability() < (10, 0), - "FP4 quantization requires compute capability 10.0 or later", -) +# @unittest.skipIf( +# torch.cuda.get_device_capability() < (10, 0), +# "FP4 quantization requires compute capability 10.0 or later", +# ) @unittest.skipIf( not importlib.util.find_spec("modelopt"), "ModelOpt is required to run this test", From f16e58a93ac2c3d8bd68db49e26283472d8ef2a9 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 14 May 2025 13:51:44 -0700 Subject: [PATCH 12/30] test --- .../dynamo/conversion/impl/nvfp4_quantize.py | 395 ++++++++++++++++++ 1 file changed, 395 insertions(+) create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py diff --git 
a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py new file mode 100644 index 0000000000..dd8c65069a --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py @@ -0,0 +1,395 @@ +from typing import Optional, Union + +import numpy as np +import tensorrt as trt +import torch +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import ( + get_trt_tensor, + to_torch, +) +from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.types import TRTTensor + + +def nvfp4_quantize( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_tensor: TRTTensor, + block_size: int, + amax: Union[np.ndarray, torch.Tensor], + num_bits: int, + exponent_bits: int, + scale_num_bits: int, + scale_exponent_bits: int, +) -> TRTTensor: + """ + Adds quantize and dequantize ops (QDQ) which quantize to FP4 based + on the output_type set and dequantizes them back. + """ + print( + f"lan added nvfp4_quantize entered: {target=} {source_ir=} {name=} {input_tensor.shape=} {input_tensor.dtype=} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}" + ) + with unset_fake_temporarily(): + axis = len(input_tensor.shape) - 1 + global_scale = _calculate_global_scale(ctx, name, amax) + if ".weight_quantizer" in name: + _test_weights_scaling_factor(input_tensor, global_scale) + output = _static_double_quantize_without_constant_folding( + ctx, + target, + source_ir, + name, + input_tensor, + global_scale, + axis, + ) + elif ".input_quantizer" in name: + # quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 + output = _dynamic_double_quantize( + ctx, + target, + source_ir, + name, + input_tensor, + global_scale, + axis, + ) + + else: + raise ValueError( + f"quantizer received an input of {name}. Supported values: weight_quantizer | input_quantizer" + ) + return output + + +def _dynamic_double_quantize( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_tensor: torch.Tensor, + global_scale: torch.Tensor, + axis: int = -1, + block_size: int = 16, + output_type: trt.DataType = trt.DataType.FP4, + scale_type: trt.DataType = trt.DataType.FP8, +) -> TRTTensor: + """ + quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 + Parameters: + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR] + name: str + input_tensor : Tensor (On GPU) + The input tensor. + global_scale : Tensor (On GPU) + The global per-tensor scaling factor. It should contain only 1 element. + axis : int + The axis to quantize. Default is -1 (the last axis). + block_size : int + The block size for quantization. Default is 16. + output_type : trt.DataType + The data type for quantized data. Default is FP4. + scale_type : trt.DataType + The data type for block scale. Default is FP8. 
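As a purely numerical reference, independent of the TensorRT layers built in this file: FP4 E2M1 can only represent the magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}. A fake-quantize helper that snaps values onto that grid is a convenient way to sanity-check what a Q/DQ pair is expected to compute; this is a sketch (tie-breaking and rounding-mode details are glossed over), not the converter's code path:

import torch

# Magnitudes representable by FP4 E2M1 (1 sign bit, 2 exponent bits, 1 mantissa bit).
_E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fake_quantize_e2m1(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Round x / scale to the nearest representable E2M1 value, then rescale back."""
    scaled = (x / scale).clamp(-6.0, 6.0)
    grid = _E2M1_GRID.to(device=scaled.device, dtype=scaled.dtype)
    # Snap each magnitude to the nearest grid point, keeping the sign.
    idx = (scaled.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)
    return torch.sign(scaled) * grid[idx] * scale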
+ + """ + global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") + # dynamic quantize input tensor to fp4 + dynamic_quantize_layer = ctx.net.add_dynamic_quantize( + input_tensor, + axis, + block_size, + output_type, + scale_type, + ) + dynamic_quantize_layer.set_input(1, global_scale) + set_layer_name( + dynamic_quantize_layer, target, name + "_dynamic_quantize", source_ir + ) + quantized_data_in_fp4 = dynamic_quantize_layer.get_output(0) + quantized_scale_in_fp8 = dynamic_quantize_layer.get_output(1) + + # dequantize scale from fp8 to orignal dtype(default is float32) + dequantize_scale_layer = ctx.net.add_dequantize( + quantized_scale_in_fp8, global_scale, input_tensor.dtype + ) + set_layer_name( + dequantize_scale_layer, target, name + "_dequantize_scale", source_ir + ) + dequantized_scale = dequantize_scale_layer.get_output(0) + + # dequantize quantized_data_in_fp4 from fp4 to orignal dtype(default is float32) + dequantize_data_layer = ctx.net.add_dequantize( + quantized_data_in_fp4, dequantized_scale, input_tensor.dtype + ) + dequantize_data_layer.axis = axis + set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) + dequantized_data = dequantize_data_layer.get_output(0) + return dequantized_data + + +# TODO: to remove it this is to make sure our global scale and block scale calculation is correct during debugging +def _test_weights_scaling_factor(weights_tensor, global_scale): + + import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor + import modelopt.onnx.quantization.quant_utils as quant_utils + + weights_scaling_factor_2 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor_2( + weights_tensor + ) + torch.allclose(weights_scaling_factor_2, global_scale) + + block_scale_f32 = quant_utils.get_weights_scaling_factor( + weights_tensor.numpy(), 16, np.float32(global_scale) + ) + block_scale_f32 = torch.from_numpy(block_scale_f32) + + block_scale = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( + weights_tensor, 16, global_scale, True + )[0] + torch.allclose(block_scale_f32, block_scale) + block_scale_fp8 = block_scale.to(torch.float8_e4m3fn) + + +def _static_double_quantize( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + weights_tensor: torch.Tensor, + global_scale: torch.Tensor, + axis: int, +) -> TRTTensor: + """ + Parameters: + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + weights_tensor : Tensor (On GPU) + The input tensor for weights. + global_scale : Tensor (On GPU) + The global per-tensor scaling factor. It should contain only 1 element. + axis: int + The axis to quantize. Default is -1 (the last axis). 
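In effect, the pair of dequantize layers built in this helper reconstructs each weight as fp4_code * fp8_block_scale * global_scale. A worked scalar example under the scaling rules above (numbers chosen purely for illustration):

# Suppose one 16-element block has amax 3.0 and the whole weight tensor has amax 96.0.
global_scale = 96.0 / (6 * 448)             # = 96 / 2688 ≈ 0.0357   (per tensor)
block_scale = (3.0 / 6) / global_scale      # = 14.0                 (stored as FP8)
# A weight of 2.9 in that block maps to 2.9 / (block_scale * global_scale) ≈ 5.8,
# which rounds to the FP4 code 6, so the reconstructed value is:
reconstructed = 6 * block_scale * global_scale   # = 3.0, i.e. 2.9 snapped onto the FP4 grid
assert abs(reconstructed - 3.0) < 1e-9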
+ Returns: + quantized data tensor in fp4 + """ + + import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor + + # import modelopt.onnx.quantization.quant_utils as quant_utils + + block_scale_fp32 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( + weights_tensor, 16, global_scale, True + )[0] + block_scale_fp8 = block_scale_fp32.to(torch.float8_e4m3fn) + + global_scale = to_torch(global_scale, None) + + # # TODO: issue1: not sure whether we need to quantize the weights tensor here, due to Icast layer does not support cast + # IBuilder::buildSerializedNetwork: Error Code 4: API Usage Error (Cast ITensor linear1.weight_quantizer/dynamic_block_quantize_op_1_weights_tensor_scaled_output from DataType.FLOAT to DataType.FP4 - [unknown_ir_ops]-[linear1.weight_quantizer/dynamic_block_quantize_op_1_cast_weights_tensor_scaled_to_fp4]: unsupported input type and output type for ICastLayer, unsupported types are: {FP8, Int4, FP4}, current input type: Float, output type: FP4) + # reference https://gitlab-master.nvidia.com/omniml/modelopt/-/blob/main/modelopt/onnx/quantization/qdq_utils.py#L955 + # weights_tensor_scaled = quant_utils.quantize(weights_tensor.numpy(), 16, block_scale_fp32.numpy(),global_scale.numpy()) + # weights_tensor_scaled = torch.from_numpy(weights_tensor_scaled) + # weights_tensor_scaled = get_trt_tensor(ctx, weights_tensor_scaled, name + "_weights_tensor_scaled") + # weights_fp4 = cast_trt_tensor(ctx, weights_tensor_scaled, trt.DataType.FP4, name + "_cast_weights_tensor_scaled_to_fp4") + + # # TODO: issue2: weights_tensor_scaled is in torch.uint8 format not sure how can this to be converted into float4_e2m1fn_x2 + # reference: https://gitlab-master.nvidia.com/omniml/modelopt/-/blob/main/modelopt/core/torch/quantization/qtensor/nvfp4_tensor.py#L136 + weights_tensor_scaled = nvfp4_tensor.NVFP4QTensor.quantize( + weights_tensor, + 16, + block_scale_fp32, + global_scale, + )[0]._quantized_data + + # # TODO: issue3: torch does not support convert to float4_e2m1fn_x2 directly got RuntimeError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2' + # weights_fp4 = weights_tensor_scaled.to(torch.float4_e2m1fn_x2) + # weights_fp4 = get_trt_tensor(ctx, weights_fp4, name + "_weights_fp4") + + global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") + block_scale = get_trt_tensor(ctx, block_scale_fp32, name + "_block_scale") + block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8") + # # quantize block scale to fp8 + # block_scale_quantize_layer = ctx.net.add_quantize(block_scale, global_scale) + # set_layer_name( + # block_scale_quantize_layer, + # target, + # name + "_block_scale_quantize", + # source_ir, + # ) + # block_scale_quantize_layer.set_output_type(0, trt.DataType.FP8) + # quantized_block_scale_in_fp8 = block_scale_quantize_layer.get_output(0) + + # dequantize block scale from fp8 to float32 + dequantize_block_scale_layer = ctx.net.add_dequantize( + block_scale_fp8, + global_scale, + block_scale.dtype, + ) + set_layer_name( + dequantize_block_scale_layer, + target, + name + "_dequantize_block_scale", + source_ir, + ) + dequantized_block_scale = dequantize_block_scale_layer.get_output(0) + + # dequantize weights tensor from fp4 to originaldtype(default is float32) + dequantize_data_layer = ctx.net.add_dequantize( + weights_fp4, + dequantized_block_scale, + trt.DataType.FLOAT, + ) + set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) + dequantized_data = 
dequantize_data_layer.get_output(0) + return dequantized_data + + +def _static_double_quantize_without_constant_folding( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + weights_tensor: torch.Tensor, + global_scale: torch.Tensor, + axis: int, +) -> TRTTensor: + """ + Parameters: + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + weights_tensor : Tensor (On GPU) + The input tensor for weights. + global_scale : Tensor (On GPU) + The global per-tensor scaling factor. It should contain only 1 element. + axis: int + The axis to quantize. Default is -1 (the last axis). + Returns: + quantized data tensor in fp4 + """ + + import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor + + # import modelopt.onnx.quantization.quant_utils as quant_utils + + block_scale = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( + weights_tensor, 16, global_scale, True + )[0] + global_scale = to_torch(global_scale, None) + + # block_scale_fp8 = block_scale.to(torch.float8_e4m3fn) + # block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8") + + global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") + block_scale = get_trt_tensor(ctx, block_scale, name + "_block_scale") + weights_tensor = get_trt_tensor(ctx, weights_tensor, name + "_weights_tensor") + + # quantize block scale to fp8 + block_scale_quantize_layer = ctx.net.add_quantize(block_scale, global_scale) + set_layer_name( + block_scale_quantize_layer, + target, + name + "_block_scale_quantize_to_fp8", + source_ir, + ) + block_scale_quantize_layer.set_output_type(0, trt.DataType.FP8) + block_scale_fp8 = block_scale_quantize_layer.get_output(0) + + # dequantize block scale from fp8 to float32 + dequantize_block_scale_layer = ctx.net.add_dequantize( + block_scale_fp8, + global_scale, + block_scale.dtype, + ) + set_layer_name( + dequantize_block_scale_layer, + target, + name + "_dequantize_block_scale_from_fp8", + source_ir, + ) + dequantized_block_scale = dequantize_block_scale_layer.get_output(0) + + # quantize weights tensor to fp4 + quantize_weights_layer = ctx.net.add_quantize( + weights_tensor, dequantized_block_scale + ) + set_layer_name( + quantize_weights_layer, + target, + name + "_quantize_weights_to_fp4", + source_ir, + ) + quantize_weights_layer.set_output_type(0, trt.DataType.FP4) + weights_fp4 = quantize_weights_layer.get_output(0) + + # dequantize weights tensor from fp4 to originaldtype(default is float32) + dequantize_weights_layer = ctx.net.add_dequantize( + weights_fp4, + dequantized_block_scale, + trt.DataType.FLOAT, + ) + set_layer_name( + dequantize_weights_layer, + target, + name + "_dequantize_weights_from_fp4", + source_ir, + ) + dequantized_data = dequantize_weights_layer.get_output(0) + return dequantized_data + + +def _calculate_global_scale( + ctx: ConversionContext, + name: str, + amax: torch.Tensor, +) -> torch.Tensor: + # calculate global scale (the global per-tensor scaling factor, should only contain 1 element) + if amax is None or amax == 0: + amax = 1.0 + amax = to_torch( + amax, None + ) # amax is calculated from input_tensor.abs().amax().float() + global_scale = torch.divide(amax, 6 * 448) + if global_scale == 0: + global_scale = 1.0 + return global_scale + + +def _calculate_block_scale( + ctx: ConversionContext, + name: str, + weights_tensor: TRTTensor, + block_size: int, +) -> TRTTensor: + amax = weights_tensor.abs().amax().float() + # reference: 
https://gitlab-master.nvidia.com/omniml/modelopt/-/blob/main/modelopt/onnx/quantization/quant_utils.py#L122 + weights_scaling_factor_2 = amax / 6 / 448 + if weights_scaling_factor_2 == 0: + weights_scaling_factor_2 = 1.0 + + # reference: https://gitlab-master.nvidia.com/omniml/modelopt/-/blob/main/modelopt/onnx/quantization/quant_utils.py#L131 + [n, k] = weights_tensor.shape[-2:] + assert block_size != 0, "block_size must be non-zero" + assert k % block_size == 0, "k must be a multiple of block_size" + reshaped_input_tensor = weights_tensor.reshape( + tuple(weights_tensor.shape[:-2]) + (n, k // block_size, block_size) + ) + + per_block_amax = reshaped_input_tensor.abs().amax(dim=-1).float() + per_block_scale = torch.divide(per_block_amax, 6) + q_per_block_scale = torch.divide(per_block_scale, weights_scaling_factor_2) + # TODO:set all zero values in scale to 1.0 + # block_scale = get_trt_tensor(ctx, q_per_block_scale, name + "_block_scale") + return q_per_block_scale From 06c81261447c7fc817e969745d595dd2ec73e4ad Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 14 May 2025 14:16:21 -0700 Subject: [PATCH 13/30] test --- examples/dynamo/vgg16_ptq.py | 4 ++++ py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/examples/dynamo/vgg16_ptq.py b/examples/dynamo/vgg16_ptq.py index 7fa943040e..0ed8772a44 100644 --- a/examples/dynamo/vgg16_ptq.py +++ b/examples/dynamo/vgg16_ptq.py @@ -200,6 +200,8 @@ def calibrate_loop(model): quant_cfg = mtq.INT8_DEFAULT_CFG elif args.quantize_type == "fp8": quant_cfg = mtq.FP8_DEFAULT_CFG +elif args.quantize_type == "fp4": + quant_cfg = mtq.NVFP4_DEFAULT_CFG # PTQ with in-place replacement to quantized modules mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) # model has FP8 qdq nodes at this point @@ -239,6 +241,8 @@ def calibrate_loop(model): enabled_precisions = {torch.int8} elif args.quantize_type == "fp8": enabled_precisions = {torch.float8_e4m3fn} + elif args.quantize_type == "fp4": + enabled_precisions = {torch.float4_e2m1fn_x2} trt_model = torchtrt.dynamo.compile( exp_program, inputs=[input_tensor], diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py index dd8c65069a..2458350715 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py @@ -35,6 +35,10 @@ def nvfp4_quantize( print( f"lan added nvfp4_quantize entered: {target=} {source_ir=} {name=} {input_tensor.shape=} {input_tensor.dtype=} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}" ) + if len(input_tensor.shape) not in (2, 3): + raise ValueError( + f"nvfp4_quantize converter received an input of {input_tensor.shape} shape. 
Supported shapes: 2D or 3D" + ) with unset_fake_temporarily(): axis = len(input_tensor.shape) - 1 global_scale = _calculate_global_scale(ctx, name, amax) From 868949c5d82741c0e567dddb334a44d8c21ac0b2 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 15 May 2025 09:10:55 -0700 Subject: [PATCH 14/30] test --- .../dynamo/conversion/converter_utils.py | 25 +++- .../dynamo/conversion/impl/nvfp4_quantize.py | 126 ++++++++++++++++-- 2 files changed, 135 insertions(+), 16 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 685f40b254..298687d38e 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -361,13 +361,33 @@ def create_constant( shape = list(torch_value.shape) if torch_value is not None: - if torch_value.dtype == torch.bfloat16: + if torch_value.dtype == torch.float8_e4m3fn: + weights = trt.Weights(type=trt.DataType.FP8, ptr=torch_value.data_ptr(), count=torch_value.numel()) + constant = ctx.net.add_constant( + shape, + weights, + ) + constant.name = name + return constant.get_output(0) + + if torch_value.dtype in [torch.bfloat16]: torch_value_fp32 = torch_value.to(torch.float32) numpy_value = torch_value_fp32.numpy() else: numpy_value = torch_value.numpy() - ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1) + + if torch_value.dtype == torch.uint8: + weights = trt.Weights(type=trt.DataType.FP4, ptr=numpy_value.ctypes.data, count=numpy_value.size * 2) + shape[1] = shape[1] * 2 + constant = ctx.net.add_constant( + shape, + weights, + ) + constant.name = name + return constant.get_output(0) + + constant = ctx.net.add_constant( shape, numpy_value, @@ -381,7 +401,6 @@ def create_constant( trt.DataType.BF16, name + "_bf16_cast", ) - return constant.get_output(0) else: raise ValueError( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py index 2458350715..e2b6677488 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py @@ -44,7 +44,7 @@ def nvfp4_quantize( global_scale = _calculate_global_scale(ctx, name, amax) if ".weight_quantizer" in name: _test_weights_scaling_factor(input_tensor, global_scale) - output = _static_double_quantize_without_constant_folding( + output = _static_double_quantize_solution_2( ctx, target, source_ir, @@ -162,8 +162,14 @@ def _test_weights_scaling_factor(weights_tensor, global_scale): torch.allclose(block_scale_f32, block_scale) block_scale_fp8 = block_scale.to(torch.float8_e4m3fn) - -def _static_double_quantize( +# # solution1: got quantized weights tensor in fp32 torch tensor format(but it cannot be directly used by TensorRT, needs to be packed into uint8 format so that TensorRT can understand it) +# # reference https://gitlab-master.nvidia.com/omniml/modelopt/-/blob/main/modelopt/onnx/quantization/qdq_utils.py#L952 +# this is the fp4 quantized weight represented in fp32 format: w_f32 = quantize(w32, block_size, sw_f32_per_block, sw_f32_per_tensor) +# this is the cast to fp4 format in onnx: w_f4 = Cast.eval(w_f32, to=onnx.TensorProto.FLOAT4E2M1) +# issue: there is no equivalent op in TensorRT API to cast fp32 to fp4 +# I don't know how to create a fp4 TRTTensor from fp32 torch tensor +# tried with ICastLayer but it does not support fp4 cast, +def _static_double_quantize_solution_1( ctx: ConversionContext, target: Target, 
source_ir: Optional[SourceIR], @@ -190,7 +196,7 @@ def _static_double_quantize( import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor - # import modelopt.onnx.quantization.quant_utils as quant_utils + import modelopt.onnx.quantization.quant_utils as quant_utils block_scale_fp32 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( weights_tensor, 16, global_scale, True @@ -199,22 +205,25 @@ def _static_double_quantize( global_scale = to_torch(global_scale, None) - # # TODO: issue1: not sure whether we need to quantize the weights tensor here, due to Icast layer does not support cast - # IBuilder::buildSerializedNetwork: Error Code 4: API Usage Error (Cast ITensor linear1.weight_quantizer/dynamic_block_quantize_op_1_weights_tensor_scaled_output from DataType.FLOAT to DataType.FP4 - [unknown_ir_ops]-[linear1.weight_quantizer/dynamic_block_quantize_op_1_cast_weights_tensor_scaled_to_fp4]: unsupported input type and output type for ICastLayer, unsupported types are: {FP8, Int4, FP4}, current input type: Float, output type: FP4) - # reference https://gitlab-master.nvidia.com/omniml/modelopt/-/blob/main/modelopt/onnx/quantization/qdq_utils.py#L955 - # weights_tensor_scaled = quant_utils.quantize(weights_tensor.numpy(), 16, block_scale_fp32.numpy(),global_scale.numpy()) - # weights_tensor_scaled = torch.from_numpy(weights_tensor_scaled) - # weights_tensor_scaled = get_trt_tensor(ctx, weights_tensor_scaled, name + "_weights_tensor_scaled") - # weights_fp4 = cast_trt_tensor(ctx, weights_tensor_scaled, trt.DataType.FP4, name + "_cast_weights_tensor_scaled_to_fp4") - # # TODO: issue2: weights_tensor_scaled is in torch.uint8 format not sure how can this to be converted into float4_e2m1fn_x2 + # error from + weights_tensor_scaled = quant_utils.quantize(weights_tensor.numpy(), 16, block_scale_fp32.numpy(),global_scale.numpy()) + weights_tensor_scaled = torch.from_numpy(weights_tensor_scaled) + weights_tensor_scaled = get_trt_tensor(ctx, weights_tensor_scaled, name + "_weights_tensor_scaled") + # tried with ICastLayer but it does not support it, + weights_fp4 = cast_trt_tensor(ctx, weights_tensor_scaled, trt.DataType.FP4, name + "_cast_weights_tensor_scaled_to_fp4") + + # # solution2: got quantized weights tensor in uint8 torch tensor format which TensorRT can directly understand it + # issue2: weights_tensor_scaled is in torch.uint8 format not sure how can this to be converted into float4_e2m1fn_x2 # reference: https://gitlab-master.nvidia.com/omniml/modelopt/-/blob/main/modelopt/core/torch/quantization/qtensor/nvfp4_tensor.py#L136 + # weights_tensor_scaled = nvfp4_tensor.NVFP4QTensor.quantize( weights_tensor, 16, block_scale_fp32, global_scale, )[0]._quantized_data + weights_fp4_represented_in_uint8 = get_trt_tensor(ctx, weights_tensor_scaled, name + "_weights_fp4_represented_in_uint8") # # TODO: issue3: torch does not support convert to float4_e2m1fn_x2 directly got RuntimeError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2' # weights_fp4 = weights_tensor_scaled.to(torch.float4_e2m1fn_x2) @@ -249,11 +258,102 @@ def _static_double_quantize( dequantized_block_scale = dequantize_block_scale_layer.get_output(0) # dequantize weights tensor from fp4 to originaldtype(default is float32) + breakpoint() dequantize_data_layer = ctx.net.add_dequantize( - weights_fp4, + weights_fp4_represented_in_uint8, + dequantized_block_scale, + trt.DataType.FLOAT, + ) + dequantize_data_layer.precision = trt.DataType.FP4 + set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", 
source_ir) + dequantized_data = dequantize_data_layer.get_output(0) + return dequantized_data + + + +def _static_double_quantize_solution_2( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + weights_tensor: torch.Tensor, + global_scale: torch.Tensor, + axis: int, +) -> TRTTensor: + """ + Parameters: + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + weights_tensor : Tensor (On GPU) + The input tensor for weights. + global_scale : Tensor (On GPU) + The global per-tensor scaling factor. It should contain only 1 element. + axis: int + The axis to quantize. Default is -1 (the last axis). + Returns: + quantized data tensor in fp4 + """ + + import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor + + block_scale_fp32 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( + weights_tensor, 16, global_scale, True + )[0] + block_scale_fp8 = block_scale_fp32.to(torch.float8_e4m3fn) + global_scale = to_torch(global_scale, None) + + # # solution2: got quantized weights tensor in uint8 torch tensor format which TensorRT can directly understand it + # reference: https://gitlab-master.nvidia.com/omniml/modelopt/-/blob/main/modelopt/core/torch/quantization/qtensor/nvfp4_tensor.py#L136 + # issue2: fp4 quantized weights tensor represented in a uint8 torch tensor format not sure how can I create a TRTTensor out of it + # https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/Graph/Layers.html#tensorrt.IConstantLayer does not support uint8 input + weights_tensor_scaled = nvfp4_tensor.NVFP4QTensor.quantize( + weights_tensor, + 16, + block_scale_fp32, + global_scale, + )[0]._quantized_data + + weights_fp4_represented_in_uint8 = get_trt_tensor(ctx, weights_tensor_scaled, name + "_weights_fp4_represented_in_uint8") + + + global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") + block_scale = get_trt_tensor(ctx, block_scale_fp32, name + "_block_scale") + block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8") + # # quantize block scale to fp8 + # block_scale_quantize_layer = ctx.net.add_quantize(block_scale, global_scale) + # set_layer_name( + # block_scale_quantize_layer, + # target, + # name + "_block_scale_quantize", + # source_ir, + # ) + # block_scale_quantize_layer.set_output_type(0, trt.DataType.FP8) + # quantized_block_scale_in_fp8 = block_scale_quantize_layer.get_output(0) + + # dequantize block scale from fp8 to float32 + dequantize_block_scale_layer = ctx.net.add_dequantize( + block_scale_fp8, + global_scale, + block_scale.dtype, + ) + set_layer_name( + dequantize_block_scale_layer, + target, + name + "_dequantize_block_scale", + source_ir, + ) + dequantized_block_scale = dequantize_block_scale_layer.get_output(0) + + # dequantize weights tensor from fp4 to originaldtype(default is float32) + breakpoint() + dequantize_data_layer = ctx.net.add_dequantize( + weights_fp4_represented_in_uint8, dequantized_block_scale, trt.DataType.FLOAT, ) + dequantize_data_layer.precision = trt.DataType.FP4 set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) dequantized_data = dequantize_data_layer.get_output(0) return dequantized_data From 5134a2ce0730c7853b131bcae57a6ea23e21098e Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 15 May 2025 09:52:39 -0700 Subject: [PATCH 15/30] test --- .../dynamo/conversion/impl/nvfp4_quantize.py | 270 +----------------- 1 file changed, 12 insertions(+), 258 deletions(-) diff --git 
a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py index e2b6677488..1b43f98181 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py @@ -44,7 +44,7 @@ def nvfp4_quantize( global_scale = _calculate_global_scale(ctx, name, amax) if ".weight_quantizer" in name: _test_weights_scaling_factor(input_tensor, global_scale) - output = _static_double_quantize_solution_2( + output = _static_double_quantize( ctx, target, source_ir, @@ -141,7 +141,10 @@ def _dynamic_double_quantize( # TODO: to remove it this is to make sure our global scale and block scale calculation is correct during debugging -def _test_weights_scaling_factor(weights_tensor, global_scale): +def _test_weights_scaling_factor( + weights_tensor: torch.Tensor, + global_scale: torch.Tensor +) -> None: import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor import modelopt.onnx.quantization.quant_utils as quant_utils @@ -162,14 +165,8 @@ def _test_weights_scaling_factor(weights_tensor, global_scale): torch.allclose(block_scale_f32, block_scale) block_scale_fp8 = block_scale.to(torch.float8_e4m3fn) -# # solution1: got quantized weights tensor in fp32 torch tensor format(but it cannot be directly used by TensorRT, needs to be packed into uint8 format so that TensorRT can understand it) -# # reference https://gitlab-master.nvidia.com/omniml/modelopt/-/blob/main/modelopt/onnx/quantization/qdq_utils.py#L952 -# this is the fp4 quantized weight represented in fp32 format: w_f32 = quantize(w32, block_size, sw_f32_per_block, sw_f32_per_tensor) -# this is the cast to fp4 format in onnx: w_f4 = Cast.eval(w_f32, to=onnx.TensorProto.FLOAT4E2M1) -# issue: there is no equivalent op in TensorRT API to cast fp32 to fp4 -# I don't know how to create a fp4 TRTTensor from fp32 torch tensor -# tried with ICastLayer but it does not support fp4 cast, -def _static_double_quantize_solution_1( + +def _static_double_quantize( ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], @@ -196,147 +193,27 @@ def _static_double_quantize_solution_1( import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor - import modelopt.onnx.quantization.quant_utils as quant_utils - - block_scale_fp32 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( - weights_tensor, 16, global_scale, True + block_scale_fp8 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( + weights_tensor, 16, global_scale, )[0] - block_scale_fp8 = block_scale_fp32.to(torch.float8_e4m3fn) - - global_scale = to_torch(global_scale, None) - - - # error from - weights_tensor_scaled = quant_utils.quantize(weights_tensor.numpy(), 16, block_scale_fp32.numpy(),global_scale.numpy()) - weights_tensor_scaled = torch.from_numpy(weights_tensor_scaled) - weights_tensor_scaled = get_trt_tensor(ctx, weights_tensor_scaled, name + "_weights_tensor_scaled") - # tried with ICastLayer but it does not support it, - weights_fp4 = cast_trt_tensor(ctx, weights_tensor_scaled, trt.DataType.FP4, name + "_cast_weights_tensor_scaled_to_fp4") - # # solution2: got quantized weights tensor in uint8 torch tensor format which TensorRT can directly understand it - # issue2: weights_tensor_scaled is in torch.uint8 format not sure how can this to be converted into float4_e2m1fn_x2 - # reference: https://gitlab-master.nvidia.com/omniml/modelopt/-/blob/main/modelopt/core/torch/quantization/qtensor/nvfp4_tensor.py#L136 - # 
weights_tensor_scaled = nvfp4_tensor.NVFP4QTensor.quantize( weights_tensor, 16, - block_scale_fp32, + block_scale_fp8, global_scale, )[0]._quantized_data - weights_fp4_represented_in_uint8 = get_trt_tensor(ctx, weights_tensor_scaled, name + "_weights_fp4_represented_in_uint8") - # # TODO: issue3: torch does not support convert to float4_e2m1fn_x2 directly got RuntimeError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2' - # weights_fp4 = weights_tensor_scaled.to(torch.float4_e2m1fn_x2) - # weights_fp4 = get_trt_tensor(ctx, weights_fp4, name + "_weights_fp4") - - global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") - block_scale = get_trt_tensor(ctx, block_scale_fp32, name + "_block_scale") block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8") - # # quantize block scale to fp8 - # block_scale_quantize_layer = ctx.net.add_quantize(block_scale, global_scale) - # set_layer_name( - # block_scale_quantize_layer, - # target, - # name + "_block_scale_quantize", - # source_ir, - # ) - # block_scale_quantize_layer.set_output_type(0, trt.DataType.FP8) - # quantized_block_scale_in_fp8 = block_scale_quantize_layer.get_output(0) - - # dequantize block scale from fp8 to float32 - dequantize_block_scale_layer = ctx.net.add_dequantize( - block_scale_fp8, - global_scale, - block_scale.dtype, - ) - set_layer_name( - dequantize_block_scale_layer, - target, - name + "_dequantize_block_scale", - source_ir, - ) - dequantized_block_scale = dequantize_block_scale_layer.get_output(0) - - # dequantize weights tensor from fp4 to originaldtype(default is float32) - breakpoint() - dequantize_data_layer = ctx.net.add_dequantize( - weights_fp4_represented_in_uint8, - dequantized_block_scale, - trt.DataType.FLOAT, - ) - dequantize_data_layer.precision = trt.DataType.FP4 - set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) - dequantized_data = dequantize_data_layer.get_output(0) - return dequantized_data - - - -def _static_double_quantize_solution_2( - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - weights_tensor: torch.Tensor, - global_scale: torch.Tensor, - axis: int, -) -> TRTTensor: - """ - Parameters: - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - weights_tensor : Tensor (On GPU) - The input tensor for weights. - global_scale : Tensor (On GPU) - The global per-tensor scaling factor. It should contain only 1 element. - axis: int - The axis to quantize. Default is -1 (the last axis). 
- Returns: - quantized data tensor in fp4 - """ - - import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor - - block_scale_fp32 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( - weights_tensor, 16, global_scale, True - )[0] - block_scale_fp8 = block_scale_fp32.to(torch.float8_e4m3fn) global_scale = to_torch(global_scale, None) - - # # solution2: got quantized weights tensor in uint8 torch tensor format which TensorRT can directly understand it - # reference: https://gitlab-master.nvidia.com/omniml/modelopt/-/blob/main/modelopt/core/torch/quantization/qtensor/nvfp4_tensor.py#L136 - # issue2: fp4 quantized weights tensor represented in a uint8 torch tensor format not sure how can I create a TRTTensor out of it - # https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/Graph/Layers.html#tensorrt.IConstantLayer does not support uint8 input - weights_tensor_scaled = nvfp4_tensor.NVFP4QTensor.quantize( - weights_tensor, - 16, - block_scale_fp32, - global_scale, - )[0]._quantized_data - - weights_fp4_represented_in_uint8 = get_trt_tensor(ctx, weights_tensor_scaled, name + "_weights_fp4_represented_in_uint8") - - global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") - block_scale = get_trt_tensor(ctx, block_scale_fp32, name + "_block_scale") - block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8") - # # quantize block scale to fp8 - # block_scale_quantize_layer = ctx.net.add_quantize(block_scale, global_scale) - # set_layer_name( - # block_scale_quantize_layer, - # target, - # name + "_block_scale_quantize", - # source_ir, - # ) - # block_scale_quantize_layer.set_output_type(0, trt.DataType.FP8) - # quantized_block_scale_in_fp8 = block_scale_quantize_layer.get_output(0) + weights_fp4_represented_in_uint8 = get_trt_tensor(ctx, weights_tensor_scaled, name + "_weights_fp4_represented_in_uint8") # dequantize block scale from fp8 to float32 dequantize_block_scale_layer = ctx.net.add_dequantize( block_scale_fp8, global_scale, - block_scale.dtype, + trt.DataType.FLOAT, ) set_layer_name( dequantize_block_scale_layer, @@ -347,7 +224,6 @@ def _static_double_quantize_solution_2( dequantized_block_scale = dequantize_block_scale_layer.get_output(0) # dequantize weights tensor from fp4 to originaldtype(default is float32) - breakpoint() dequantize_data_layer = ctx.net.add_dequantize( weights_fp4_represented_in_uint8, dequantized_block_scale, @@ -359,101 +235,6 @@ def _static_double_quantize_solution_2( return dequantized_data -def _static_double_quantize_without_constant_folding( - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - weights_tensor: torch.Tensor, - global_scale: torch.Tensor, - axis: int, -) -> TRTTensor: - """ - Parameters: - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - weights_tensor : Tensor (On GPU) - The input tensor for weights. - global_scale : Tensor (On GPU) - The global per-tensor scaling factor. It should contain only 1 element. - axis: int - The axis to quantize. Default is -1 (the last axis). 
- Returns: - quantized data tensor in fp4 - """ - - import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor - - # import modelopt.onnx.quantization.quant_utils as quant_utils - - block_scale = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( - weights_tensor, 16, global_scale, True - )[0] - global_scale = to_torch(global_scale, None) - - # block_scale_fp8 = block_scale.to(torch.float8_e4m3fn) - # block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8") - - global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") - block_scale = get_trt_tensor(ctx, block_scale, name + "_block_scale") - weights_tensor = get_trt_tensor(ctx, weights_tensor, name + "_weights_tensor") - - # quantize block scale to fp8 - block_scale_quantize_layer = ctx.net.add_quantize(block_scale, global_scale) - set_layer_name( - block_scale_quantize_layer, - target, - name + "_block_scale_quantize_to_fp8", - source_ir, - ) - block_scale_quantize_layer.set_output_type(0, trt.DataType.FP8) - block_scale_fp8 = block_scale_quantize_layer.get_output(0) - - # dequantize block scale from fp8 to float32 - dequantize_block_scale_layer = ctx.net.add_dequantize( - block_scale_fp8, - global_scale, - block_scale.dtype, - ) - set_layer_name( - dequantize_block_scale_layer, - target, - name + "_dequantize_block_scale_from_fp8", - source_ir, - ) - dequantized_block_scale = dequantize_block_scale_layer.get_output(0) - - # quantize weights tensor to fp4 - quantize_weights_layer = ctx.net.add_quantize( - weights_tensor, dequantized_block_scale - ) - set_layer_name( - quantize_weights_layer, - target, - name + "_quantize_weights_to_fp4", - source_ir, - ) - quantize_weights_layer.set_output_type(0, trt.DataType.FP4) - weights_fp4 = quantize_weights_layer.get_output(0) - - # dequantize weights tensor from fp4 to originaldtype(default is float32) - dequantize_weights_layer = ctx.net.add_dequantize( - weights_fp4, - dequantized_block_scale, - trt.DataType.FLOAT, - ) - set_layer_name( - dequantize_weights_layer, - target, - name + "_dequantize_weights_from_fp4", - source_ir, - ) - dequantized_data = dequantize_weights_layer.get_output(0) - return dequantized_data - - def _calculate_global_scale( ctx: ConversionContext, name: str, @@ -470,30 +251,3 @@ def _calculate_global_scale( global_scale = 1.0 return global_scale - -def _calculate_block_scale( - ctx: ConversionContext, - name: str, - weights_tensor: TRTTensor, - block_size: int, -) -> TRTTensor: - amax = weights_tensor.abs().amax().float() - # reference: https://gitlab-master.nvidia.com/omniml/modelopt/-/blob/main/modelopt/onnx/quantization/quant_utils.py#L122 - weights_scaling_factor_2 = amax / 6 / 448 - if weights_scaling_factor_2 == 0: - weights_scaling_factor_2 = 1.0 - - # reference: https://gitlab-master.nvidia.com/omniml/modelopt/-/blob/main/modelopt/onnx/quantization/quant_utils.py#L131 - [n, k] = weights_tensor.shape[-2:] - assert block_size != 0, "block_size must be non-zero" - assert k % block_size == 0, "k must be a multiple of block_size" - reshaped_input_tensor = weights_tensor.reshape( - tuple(weights_tensor.shape[:-2]) + (n, k // block_size, block_size) - ) - - per_block_amax = reshaped_input_tensor.abs().amax(dim=-1).float() - per_block_scale = torch.divide(per_block_amax, 6) - q_per_block_scale = torch.divide(per_block_scale, weights_scaling_factor_2) - # TODO:set all zero values in scale to 1.0 - # block_scale = get_trt_tensor(ctx, q_per_block_scale, name + "_block_scale") - return q_per_block_scale From 
391c9710e05a35ce88c8a8e66f448e3778fe3a40 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 15 May 2025 10:27:27 -0700 Subject: [PATCH 16/30] test --- .../dynamo/conversion/converter_utils.py | 31 ++- .../dynamo/conversion/impl/quantize.py | 251 ------------------ 2 files changed, 18 insertions(+), 264 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 298687d38e..eb18a14eca 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -362,32 +362,37 @@ def create_constant( if torch_value is not None: if torch_value.dtype == torch.float8_e4m3fn: - weights = trt.Weights(type=trt.DataType.FP8, ptr=torch_value.data_ptr(), count=torch_value.numel()) + weights = trt.Weights( + type=trt.DataType.FP8, + ptr=torch_value.data_ptr(), + count=torch_value.numel(), + ) constant = ctx.net.add_constant( shape, weights, ) constant.name = name return constant.get_output(0) - - if torch_value.dtype in [torch.bfloat16]: - torch_value_fp32 = torch_value.to(torch.float32) - numpy_value = torch_value_fp32.numpy() - else: - numpy_value = torch_value.numpy() - ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1) - + # Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8 if torch_value.dtype == torch.uint8: - weights = trt.Weights(type=trt.DataType.FP4, ptr=numpy_value.ctypes.data, count=numpy_value.size * 2) - shape[1] = shape[1] * 2 + weights = trt.Weights( + type=trt.DataType.FP4, + ptr=torch_value.data_ptr(), + count=torch_value.numel() * 2, + ) + shape[-1] = shape[-1] * 2 constant = ctx.net.add_constant( shape, weights, ) constant.name = name return constant.get_output(0) - - + if torch_value.dtype == torch.bfloat16: + torch_value_fp32 = torch_value.to(torch.float32) + numpy_value = torch_value_fp32.numpy() + else: + numpy_value = torch_value.numpy() + ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1) constant = ctx.net.add_constant( shape, numpy_value, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py index 192f9c648a..e472ed3092 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py @@ -28,8 +28,6 @@ def quantize( """ with unset_fake_temporarily(): - if not isinstance(input_tensor, TRTTensor): - input_tensor = get_trt_tensor(ctx, input_tensor, name + "_quantize_input") if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in ( trt.float32, trt.float16, @@ -69,252 +67,3 @@ def quantize( dq_output = dequantize_layer.get_output(0) return dq_output - - -# def nvfp4_quantize( -# ctx: ConversionContext, -# target: Target, -# source_ir: Optional[SourceIR], -# name: str, -# input_tensor: TRTTensor, -# block_size: int, -# amax: Union[np.ndarray, torch.Tensor], -# num_bits: int, -# exponent_bits: int, -# scale_num_bits: int, -# scale_exponent_bits: int, -# ) -> TRTTensor: -# """ -# Adds quantize and dequantize ops (QDQ) which quantize to FP4 based -# on the output_type set and dequantizes them back. 
-# """ -# print( -# f"lan added nvfp4_quantize entered: {target=} {source_ir=} {name=} {input_tensor.shape=} {input_tensor.dtype=} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}" -# ) -# with unset_fake_temporarily(): -# if input_tensor.dtype not in ( -# trt.float32, -# trt.float16, -# trt.bfloat16, -# torch.float32, -# torch.float16, -# torch.bfloat16, -# ): -# raise ValueError( -# f"dynamic_block_quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16 | bfloat16" -# ) -# if len(input_tensor.shape) not in (2, 3): -# raise ValueError( -# f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D" -# ) -# axis = len(input_tensor.shape) - 1 - -# # TODO: ADD PADDING IF NEEDED -# # TODO: ADD DYNAMIC SHAPE SUPPORT - -# global_scale = _calculate_global_scale(ctx, name, amax) - -# if ".weight_quantizer" in name: -# block_scale = _calculate_block_scale( -# ctx, -# name, -# input_tensor, -# block_size, -# ) -# input_tensor = get_trt_tensor(ctx, input_tensor, name + "_input") -# output = _static_double_quantize( -# ctx, -# target, -# source_ir, -# name, -# input_tensor, -# block_scale, -# global_scale, -# ) -# elif ".input_quantizer" in name: -# # quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 -# output = _dynamic_double_quantize( -# ctx, -# target, -# source_ir, -# name, -# input_tensor, -# global_scale, -# ) - -# else: -# raise ValueError( -# f"quantizer received an input of {name}. Supported values: weight_quantizer | input_quantizer" -# ) -# return output - - -# def _dynamic_double_quantize( -# ctx: ConversionContext, -# target: Target, -# source_ir: Optional[SourceIR], -# name: str, -# input_tensor: TRTTensor, -# global_scale: TRTTensor, -# axis: int = -1, -# block_size: int = 16, -# output_type: trt.DataType = trt.DataType.FP4, -# scale_type: trt.DataType = trt.DataType.FP8, -# ) -> TRTTensor: -# """ -# quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 -# Parameters: -# ctx: ConversionContext, -# target: Target, -# source_ir: Optional[SourceIR] -# name: str -# input_tensor : Tensor (On GPU) -# The input tensor. -# global_scale : Tensor (On GPU) -# The global per-tensor scaling factor. It should contain only 1 element. -# axis : int -# The axis to quantize. Default is -1 (the last axis). -# block_size : int -# The block size for quantization. Default is 16. -# output_type : trt.DataType -# The data type for quantized data. Default is FP4. -# scale_type : trt.DataType -# The data type for block scale. Default is FP8. 
- -# """ -# # dynamic quantize input tensor to fp4 -# dynamic_quantize_layer = ctx.net.add_dynamic_quantize( -# input_tensor, -# axis, -# block_size, -# output_type, -# scale_type, -# ) -# dynamic_quantize_layer.set_input(1, global_scale) -# set_layer_name( -# dynamic_quantize_layer, target, name + "_dynamic_quantize", source_ir -# ) -# quantized_data_in_fp4 = dynamic_quantize_layer.get_output(0) -# quantized_scale_in_fp8 = dynamic_quantize_layer.get_output(1) - -# # dequantize scale from fp8 to orignal dtype(default is float32) -# dequantize_scale_layer = ctx.net.add_dequantize( -# quantized_scale_in_fp8, global_scale, input_tensor.dtype -# ) -# set_layer_name( -# dequantize_scale_layer, target, name + "_dequantize_scale", source_ir -# ) -# dequantized_scale = dequantize_scale_layer.get_output(0) - -# # dequantize quantized_data_in_fp4 from fp4 to orignal dtype(default is float32) -# dequantize_data_layer = ctx.net.add_dequantize( -# quantized_data_in_fp4, dequantized_scale, input_tensor.dtype -# ) -# dequantize_data_layer.axis = axis -# set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) -# dequantized_data = dequantize_data_layer.get_output(0) -# return dequantized_data - - -# def _static_double_quantize( -# ctx: ConversionContext, -# target: Target, -# source_ir: Optional[SourceIR], -# name: str, -# input_tensor: TRTTensor, -# block_scale: TRTTensor, -# global_scale: TRTTensor, -# ) -> TRTTensor: -# """ -# Parameters: -# ctx: ConversionContext, -# target: Target, -# source_ir: Optional[SourceIR], -# name: str, -# input_tensor : Tensor (On GPU) -# The input tensor. -# block_scale : Tensor (On GPU) -# The per-block scaling factor. -# global_scale : Tensor (On GPU) -# The global per-tensor scaling factor. It should contain only 1 element. 
-# Returns: -# A tuple of two tensors: quantized data tensor in fp4 and quantized block scaling factor tensor in fp8 -# """ -# # quantize block scale to fp8 -# block_scale_quantize_layer = ctx.net.add_quantize(block_scale, global_scale) -# set_layer_name( -# block_scale_quantize_layer, -# target, -# name + "_block_scale_quantize", -# source_ir, -# ) -# block_scale_quantize_layer.set_output_type(0, trt.DataType.FP8) -# quantized_block_scale_in_fp8 = block_scale_quantize_layer.get_output(0) - -# # dequantize block scale from fp8 to original dtype(default is float32) -# dequantize_block_scale_layer = ctx.net.add_dequantize( -# quantized_block_scale_in_fp8, -# global_scale, -# block_scale.dtype, -# ) -# set_layer_name( -# dequantize_block_scale_layer, -# target, -# name + "_dequantize_block_scale", -# source_ir, -# ) -# dequantized_block_scale = dequantize_block_scale_layer.get_output(0) - -# # quantize input tensor to fp4 -# data_quantize_layer = ctx.net.add_quantize(input_tensor, dequantized_block_scale) -# set_layer_name(data_quantize_layer, target, name + "_data_quantize", source_ir) -# data_quantize_layer.set_output_type(0, trt.DataType.FP4) -# quantized_data_in_fp4 = data_quantize_layer.get_output(0) - -# # dequantize input tensor from fp4 to originaldtype(default is float32) -# dequantize_data_layer = ctx.net.add_dequantize( -# quantized_data_in_fp4, -# dequantized_block_scale, -# input_tensor.dtype, -# ) -# set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) -# dequantized_data = dequantize_data_layer.get_output(0) -# return dequantized_data - - -# def _calculate_global_scale( -# ctx: ConversionContext, -# name: str, -# amax: TRTTensor, -# ) -> TRTTensor: -# # calculate global scale (the global per-tensor scaling factor, should only contain 1 element) -# if amax is None or amax == 0: -# amax = 1.0 -# amax = to_torch( -# amax, None -# ) # amax is calculated from input_tensor.abs().amax().float() -# global_scale = torch.divide(amax, 6) # 6*448 -# global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") -# return global_scale - - -# def _calculate_block_scale( -# ctx: ConversionContext, -# name: str, -# input_tensor: TRTTensor, -# block_size: int, -# ) -> TRTTensor: - -# [n, k] = input_tensor.shape[-2:] -# assert block_size != 0, "block_size must be non-zero" -# assert k % block_size == 0, "k must be a multiple of block_size" -# reshaped_input_tensor = input_tensor.reshape( -# tuple(input_tensor.shape[:-2]) + (n, k // block_size, block_size) -# ) -# amax = input_tensor.abs().amax().float() -# amax = torch.divide(amax, 6*448) -# block_amax = reshaped_input_tensor.abs().amax(dim=-1).float() -# block_scale = torch.divide(block_amax, 6) -# block_scale = torch.divide(block_scale, amax) -# block_scale = get_trt_tensor(ctx, block_scale, name + "_block_scale") -# return block_scale From 38297bd41ac14c1e0cbc99bf696998349c32b688 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 15 May 2025 14:33:15 -0700 Subject: [PATCH 17/30] test --- .../dynamo/conversion/_TRTInterpreter.py | 2 ++ tests/py/dynamo/models/test_models_export.py | 14 +++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index da5f3b36c9..292fe0d637 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -283,6 +283,8 @@ def _populate_trt_builder_config( if dtype.fp8 in 
self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.FP8) + if dtype.fp8 in self.compilation_settings.enabled_precisions: + builder_config.set_flag(trt.BuilderFlag.FP4) if dtype.bfloat16 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.BF16) diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 52d746f995..a8b29db6b5 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -233,19 +233,31 @@ def calibrate_loop(model): mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) # model has qdq nodes at this point output_pyt = model(input_tensor) + torch.onnx.export(model, input_tensor, "mtq_model.onnx") with torch.no_grad(): with export_torch_mode(): exp_program = torch.export.export(model, (input_tensor,), strict=False) + from torch.fx import passes + + g = passes.graph_drawer.FxGraphDrawer(exp_program, "torch_export_fp4") + with open("a.svg", "wb") as f: + f.write(g.get_dot_graph().create_svg()) + trt_model = torchtrt.dynamo.compile( exp_program, inputs=[input_tensor], - enabled_precisions={torch.float4_e2m1fn_x2}, + enabled_precisions={ + torch.float4_e2m1fn_x2, + torch.float32, + torch.float16, + }, min_block_size=1, debug=True, cache_built_engines=False, reuse_cached_engines=False, ) + outputs_trt = trt_model(input_tensor) print(f"lan added outputs_trt: {outputs_trt}") print(f"lan added output_pyt: {output_pyt}") From 095251f22d51e598911742526963013a8fc2d761 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 15 May 2025 15:36:13 -0700 Subject: [PATCH 18/30] add print graph --- tests/py/dynamo/models/test_models_export.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index a8b29db6b5..a08ff02f4a 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -215,7 +215,7 @@ def test_base_fp4(ir): class SimpleNetwork(torch.nn.Module): def __init__(self): super(SimpleNetwork, self).__init__() - self.linear1 = torch.nn.Linear(in_features=16, out_features=3) + self.linear1 = torch.nn.Linear(in_features=64, out_features=32) def forward(self, x): x = self.linear1(x) @@ -225,7 +225,7 @@ def calibrate_loop(model): """Simple calibration function for testing.""" model(input_tensor) - input_tensor = torch.randn(5, 16).cuda() + input_tensor = torch.randn(128, 64).cuda() print(f"lan added amax: {input_tensor.abs().amax()}") model = SimpleNetwork().eval().cuda() @@ -259,8 +259,10 @@ def calibrate_loop(model): ) outputs_trt = trt_model(input_tensor) - print(f"lan added outputs_trt: {outputs_trt}") - print(f"lan added output_pyt: {output_pyt}") + print(f"lan added torch_tensorrtoutputs_trt: {outputs_trt}") + print(f"lan added pytorchoutput_pyt: {output_pyt}") + abs_diff = torch.abs(output_pyt - outputs_trt) + print(f"lan added max abs_diff: {abs_diff.max().item()}") assert torch.allclose(output_pyt, outputs_trt, rtol=4e-1, atol=4e-1) From 58302110e8d6c4fcd632099b7025241d35f3bf27 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Fri, 16 May 2025 08:53:24 -0700 Subject: [PATCH 19/30] test float16 --- examples/dynamo/vgg16_ptq.py | 2 ++ .../dynamo/conversion/_TRTInterpreter.py | 2 +- .../dynamo/conversion/impl/nvfp4_quantize.py | 29 ++++++++++++------- tests/py/dynamo/models/test_models_export.py | 4 +-- 4 files changed, 24 insertions(+), 13 
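For reference, a minimal sketch (not part of the diff) of the precision-to-builder-flag mapping these hunks implement piecewise, assuming a recent TensorRT 10.x that exposes the FP8 and FP4 flags. The FP4 flag is keyed on dtype.fp4; the dtype.fp8 gate in this first version is corrected in patch 19 of this series, and later patches switch to a strongly typed network where these flags are not set at all.

import tensorrt as trt
import torch

# Mapping from enabled torch precisions to TensorRT builder flags (weakly typed networks only).
PRECISION_TO_FLAG = {
    torch.float16: trt.BuilderFlag.FP16,
    torch.bfloat16: trt.BuilderFlag.BF16,
    torch.int8: trt.BuilderFlag.INT8,
    torch.float8_e4m3fn: trt.BuilderFlag.FP8,
    torch.float4_e2m1fn_x2: trt.BuilderFlag.FP4,
}

def set_precision_flags(builder_config: "trt.IBuilderConfig", enabled_precisions: set) -> None:
    for torch_dtype, flag in PRECISION_TO_FLAG.items():
        if torch_dtype in enabled_precisions:
            builder_config.set_flag(flag)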
deletions(-) diff --git a/examples/dynamo/vgg16_ptq.py b/examples/dynamo/vgg16_ptq.py index 0ed8772a44..1cccb3a0a1 100644 --- a/examples/dynamo/vgg16_ptq.py +++ b/examples/dynamo/vgg16_ptq.py @@ -204,6 +204,7 @@ def calibrate_loop(model): quant_cfg = mtq.NVFP4_DEFAULT_CFG # PTQ with in-place replacement to quantized modules mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + # model has FP8 qdq nodes at this point # %% @@ -235,6 +236,7 @@ def calibrate_loop(model): with export_torch_mode(): # Compile the model with Torch-TensorRT Dynamo backend input_tensor = images.cuda() + torch.onnx.export(model, input_tensor, "mtq_vgg16_model.onnx") exp_program = torch.export.export(model, (input_tensor,), strict=False) if args.quantize_type == "int8": diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 292fe0d637..c51e56fe9b 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -283,7 +283,7 @@ def _populate_trt_builder_config( if dtype.fp8 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.FP8) - if dtype.fp8 in self.compilation_settings.enabled_precisions: + if dtype.fp4 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.FP4) if dtype.bfloat16 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.BF16) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py index 1b43f98181..1fc7d2a6c3 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py @@ -8,6 +8,7 @@ from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( + cast_trt_tensor, get_trt_tensor, to_torch, ) @@ -43,7 +44,8 @@ def nvfp4_quantize( axis = len(input_tensor.shape) - 1 global_scale = _calculate_global_scale(ctx, name, amax) if ".weight_quantizer" in name: - _test_weights_scaling_factor(input_tensor, global_scale) + # _test_weights_scaling_factor(input_tensor, global_scale) + input_tensor = input_tensor.to(torch.float16) output = _static_double_quantize( ctx, target, @@ -54,6 +56,9 @@ def nvfp4_quantize( axis, ) elif ".input_quantizer" in name: + input_tensor = cast_trt_tensor( + ctx, input_tensor, trt.float16, name + "_input_tensor_f16" + ) # quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 output = _dynamic_double_quantize( ctx, @@ -85,7 +90,7 @@ def _dynamic_double_quantize( scale_type: trt.DataType = trt.DataType.FP8, ) -> TRTTensor: """ - quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 + quantize input tensor to fp4 Parameters: ctx: ConversionContext, target: Target, @@ -125,6 +130,7 @@ def _dynamic_double_quantize( dequantize_scale_layer = ctx.net.add_dequantize( quantized_scale_in_fp8, global_scale, input_tensor.dtype ) + dequantize_scale_layer.axis = axis set_layer_name( dequantize_scale_layer, target, name + "_dequantize_scale", source_ir ) @@ -142,8 +148,7 @@ def _dynamic_double_quantize( # TODO: to remove it this is to make sure our global scale and block scale calculation is correct during debugging def _test_weights_scaling_factor( - weights_tensor: torch.Tensor, - 
global_scale: torch.Tensor + weights_tensor: torch.Tensor, global_scale: torch.Tensor ) -> None: import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor @@ -194,10 +199,12 @@ def _static_double_quantize( import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor block_scale_fp8 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( - weights_tensor, 16, global_scale, + weights_tensor, + 16, + global_scale, )[0] - weights_tensor_scaled = nvfp4_tensor.NVFP4QTensor.quantize( + weights_tensor_fp4 = nvfp4_tensor.NVFP4QTensor.quantize( weights_tensor, 16, block_scale_fp8, @@ -207,7 +214,7 @@ def _static_double_quantize( block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8") global_scale = to_torch(global_scale, None) global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") - weights_fp4_represented_in_uint8 = get_trt_tensor(ctx, weights_tensor_scaled, name + "_weights_fp4_represented_in_uint8") + weights_tensor_fp4 = get_trt_tensor(ctx, weights_tensor_fp4, name + "_weights_fp4") # dequantize block scale from fp8 to float32 dequantize_block_scale_layer = ctx.net.add_dequantize( @@ -215,6 +222,8 @@ def _static_double_quantize( global_scale, trt.DataType.FLOAT, ) + dequantize_block_scale_layer.axis = axis + dequantize_block_scale_layer.precision = trt.DataType.FP8 set_layer_name( dequantize_block_scale_layer, target, @@ -225,10 +234,11 @@ def _static_double_quantize( # dequantize weights tensor from fp4 to originaldtype(default is float32) dequantize_data_layer = ctx.net.add_dequantize( - weights_fp4_represented_in_uint8, + weights_tensor_fp4, dequantized_block_scale, - trt.DataType.FLOAT, + trt.DataType.HALF, ) + dequantize_data_layer.axis = axis dequantize_data_layer.precision = trt.DataType.FP4 set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) dequantized_data = dequantize_data_layer.get_output(0) @@ -250,4 +260,3 @@ def _calculate_global_scale( if global_scale == 0: global_scale = 1.0 return global_scale - diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index a08ff02f4a..1a49cb055b 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -259,8 +259,8 @@ def calibrate_loop(model): ) outputs_trt = trt_model(input_tensor) - print(f"lan added torch_tensorrtoutputs_trt: {outputs_trt}") - print(f"lan added pytorchoutput_pyt: {output_pyt}") + print(f"lan added torch_tensorrt outputs_trt: {outputs_trt}") + print(f"lan added pytorch output_pyt: {output_pyt}") abs_diff = torch.abs(output_pyt - outputs_trt) print(f"lan added max abs_diff: {abs_diff.max().item()}") assert torch.allclose(output_pyt, outputs_trt, rtol=4e-1, atol=4e-1) From 8f57c86f39b3091d8ba8931733057220ed42176a Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Fri, 16 May 2025 10:17:25 -0700 Subject: [PATCH 20/30] change to float16 --- .../dynamo/conversion/impl/nvfp4_quantize.py | 18 ++++++++++++------ tests/py/dynamo/models/test_models_export.py | 6 ++++-- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py index 1fc7d2a6c3..69284d1edc 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py @@ -45,7 +45,6 @@ def nvfp4_quantize( global_scale = _calculate_global_scale(ctx, name, amax) if 
".weight_quantizer" in name: # _test_weights_scaling_factor(input_tensor, global_scale) - input_tensor = input_tensor.to(torch.float16) output = _static_double_quantize( ctx, target, @@ -56,9 +55,6 @@ def nvfp4_quantize( axis, ) elif ".input_quantizer" in name: - input_tensor = cast_trt_tensor( - ctx, input_tensor, trt.float16, name + "_input_tensor_f16" - ) # quantize input tensor to fp4, output should be data tensor in fp4 and block scale tensor in fp8 output = _dynamic_double_quantize( ctx, @@ -111,6 +107,9 @@ def _dynamic_double_quantize( """ global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") + + if input_tensor.dtype not in [trt.DataType.HALF, trt.DataType.FLOAT]: + raise ValueError(f"Currently try float16, float32 only on input tensor for now. Unsupported dtype: {input_tensor.dtype}") # dynamic quantize input tensor to fp4 dynamic_quantize_layer = ctx.net.add_dynamic_quantize( input_tensor, @@ -197,6 +196,13 @@ def _static_double_quantize( """ import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor + + if weights_tensor.dtype == torch.float16: + original_dtype = trt.DataType.HALF + elif weights_tensor.dtype == torch.float32: + original_dtype = trt.DataType.FLOAT + else: + raise ValueError(f"Currently try float16, float32 only on weights tensor. Unsupported dtype: {weights_tensor.dtype}") block_scale_fp8 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( weights_tensor, @@ -220,7 +226,7 @@ def _static_double_quantize( dequantize_block_scale_layer = ctx.net.add_dequantize( block_scale_fp8, global_scale, - trt.DataType.FLOAT, + original_dtype, ) dequantize_block_scale_layer.axis = axis dequantize_block_scale_layer.precision = trt.DataType.FP8 @@ -236,7 +242,7 @@ def _static_double_quantize( dequantize_data_layer = ctx.net.add_dequantize( weights_tensor_fp4, dequantized_block_scale, - trt.DataType.HALF, + original_dtype, ) dequantize_data_layer.axis = axis dequantize_data_layer.precision = trt.DataType.FP4 diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 1a49cb055b..88b51bb1b3 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -215,7 +215,7 @@ def test_base_fp4(ir): class SimpleNetwork(torch.nn.Module): def __init__(self): super(SimpleNetwork, self).__init__() - self.linear1 = torch.nn.Linear(in_features=64, out_features=32) + self.linear1 = torch.nn.Linear(in_features=64, out_features=32, bias=False, dtype=torch.float16) def forward(self, x): x = self.linear1(x) @@ -225,7 +225,8 @@ def calibrate_loop(model): """Simple calibration function for testing.""" model(input_tensor) - input_tensor = torch.randn(128, 64).cuda() + input_tensor = torch.randn(128, 64, dtype=torch.float16).cuda() + breakpoint() print(f"lan added amax: {input_tensor.abs().amax()}") model = SimpleNetwork().eval().cuda() @@ -249,6 +250,7 @@ def calibrate_loop(model): inputs=[input_tensor], enabled_precisions={ torch.float4_e2m1fn_x2, + torch.float8_e4m3fn, torch.float32, torch.float16, }, From 24d060280c5983ab3a7bc1cc1840cdc78fab7b03 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sun, 18 May 2025 08:55:05 -0700 Subject: [PATCH 21/30] upgrade to 10.10.0 for tensorrt --- MODULE.bazel | 4 +-- examples/dynamo/vgg16_ptq.py | 8 +++-- .../dynamo/conversion/impl/nvfp4_quantize.py | 30 ++++++++++++++----- pyproject.toml | 10 +++---- tests/py/dynamo/models/test_models_export.py | 6 ++-- 5 files changed, 39 insertions(+), 19 deletions(-) diff --git 
a/MODULE.bazel b/MODULE.bazel index 82bb4ba79e..17baec14e9 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -72,9 +72,9 @@ http_archive( http_archive( name = "tensorrt", build_file = "@//third_party/tensorrt/archive:BUILD", - strip_prefix = "TensorRT-10.9.0.34", + strip_prefix = "TensorRT-10.10.0.31", urls = [ - "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.9.0/tars/TensorRT-10.9.0.34.Linux.x86_64-gnu.cuda-12.8.tar.gz", + "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.10.0/tars/TensorRT-10.10.0.31.Linux.x86_64-gnu.cuda-12.9.tar.gz", ], ) diff --git a/examples/dynamo/vgg16_ptq.py b/examples/dynamo/vgg16_ptq.py index 1cccb3a0a1..7af490afb4 100644 --- a/examples/dynamo/vgg16_ptq.py +++ b/examples/dynamo/vgg16_ptq.py @@ -169,7 +169,6 @@ def vgg16(num_classes=1000, init_weights=False): data = iter(training_dataloader) images, _ = next(data) - crit = nn.CrossEntropyLoss() # %% @@ -244,7 +243,12 @@ def calibrate_loop(model): elif args.quantize_type == "fp8": enabled_precisions = {torch.float8_e4m3fn} elif args.quantize_type == "fp4": - enabled_precisions = {torch.float4_e2m1fn_x2} + enabled_precisions = { + torch.float4_e2m1fn_x2, + torch.float8_e4m3fn, + torch.float16, + torch.float32, + } trt_model = torchtrt.dynamo.compile( exp_program, inputs=[input_tensor], diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py index 69284d1edc..a610228379 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py @@ -8,7 +8,6 @@ from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( - cast_trt_tensor, get_trt_tensor, to_torch, ) @@ -43,6 +42,8 @@ def nvfp4_quantize( with unset_fake_temporarily(): axis = len(input_tensor.shape) - 1 global_scale = _calculate_global_scale(ctx, name, amax) + print(f"lan added global_scale: {input_tensor.shape=} {input_tensor.dtype=}") + print(f"lan added global_scale: {global_scale.shape=} {global_scale.dtype=}") if ".weight_quantizer" in name: # _test_weights_scaling_factor(input_tensor, global_scale) output = _static_double_quantize( @@ -109,7 +110,9 @@ def _dynamic_double_quantize( global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") if input_tensor.dtype not in [trt.DataType.HALF, trt.DataType.FLOAT]: - raise ValueError(f"Currently try float16, float32 only on input tensor for now. Unsupported dtype: {input_tensor.dtype}") + raise ValueError( + f"Currently try float16, float32 only on input tensor for now. 
Unsupported dtype: {input_tensor.dtype}" + ) # dynamic quantize input tensor to fp4 dynamic_quantize_layer = ctx.net.add_dynamic_quantize( input_tensor, @@ -130,16 +133,18 @@ def _dynamic_double_quantize( quantized_scale_in_fp8, global_scale, input_tensor.dtype ) dequantize_scale_layer.axis = axis + dequantize_scale_layer.set_output_type(0, input_tensor.dtype) set_layer_name( dequantize_scale_layer, target, name + "_dequantize_scale", source_ir ) dequantized_scale = dequantize_scale_layer.get_output(0) - # dequantize quantized_data_in_fp4 from fp4 to orignal dtype(default is float32) + # dequantize quantized_data_in_fp4 from fp4 to orignal dtype(default is float32) dequantize_data_layer = ctx.net.add_dequantize( quantized_data_in_fp4, dequantized_scale, input_tensor.dtype ) dequantize_data_layer.axis = axis + dequantize_scale_layer.set_output_type(0, input_tensor.dtype) set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) dequantized_data = dequantize_data_layer.get_output(0) return dequantized_data @@ -196,13 +201,15 @@ def _static_double_quantize( """ import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor - + if weights_tensor.dtype == torch.float16: original_dtype = trt.DataType.HALF elif weights_tensor.dtype == torch.float32: original_dtype = trt.DataType.FLOAT else: - raise ValueError(f"Currently try float16, float32 only on weights tensor. Unsupported dtype: {weights_tensor.dtype}") + raise ValueError( + f"Currently try float16, float32 only on weights tensor. Unsupported dtype: {weights_tensor.dtype}" + ) block_scale_fp8 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( weights_tensor, @@ -221,8 +228,7 @@ def _static_double_quantize( global_scale = to_torch(global_scale, None) global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") weights_tensor_fp4 = get_trt_tensor(ctx, weights_tensor_fp4, name + "_weights_fp4") - - # dequantize block scale from fp8 to float32 + # dequantize block scale from fp8 to original dtype (default is float32) dequantize_block_scale_layer = ctx.net.add_dequantize( block_scale_fp8, global_scale, @@ -230,15 +236,19 @@ def _static_double_quantize( ) dequantize_block_scale_layer.axis = axis dequantize_block_scale_layer.precision = trt.DataType.FP8 + dequantize_block_scale_layer.set_output_type(0, original_dtype) set_layer_name( dequantize_block_scale_layer, target, name + "_dequantize_block_scale", source_ir, ) + print( + f"lan added dequantize_block_scale_layer: {dequantize_block_scale_layer.axis=} {dequantize_block_scale_layer.precision=} {dequantize_block_scale_layer.get_output_type(0)=}" + ) dequantized_block_scale = dequantize_block_scale_layer.get_output(0) - # dequantize weights tensor from fp4 to originaldtype(default is float32) + # dequantize weights tensor from fp4 to original dtype(default is float32) dequantize_data_layer = ctx.net.add_dequantize( weights_tensor_fp4, dequantized_block_scale, @@ -246,7 +256,11 @@ def _static_double_quantize( ) dequantize_data_layer.axis = axis dequantize_data_layer.precision = trt.DataType.FP4 + dequantize_data_layer.set_output_type(0, original_dtype) set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) + print( + f"lan added dequantize_data_layer: {dequantize_data_layer.axis=} {dequantize_data_layer.precision=} {dequantize_data_layer.get_output_type(0)=}" + ) dequantized_data = dequantize_data_layer.get_output(0) return dequantized_data diff --git a/pyproject.toml b/pyproject.toml index 3bb857e3e0..e2878bc7bc 
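As a numerical reference for the dynamic path above (dynamic quantize to FP4 with per-block scales, followed by two dequantizes), here is a minimal eager-mode sketch of the computation. It is illustrative only: it assumes block size 16 and the FP4 E2M1 magnitude set {0, 0.5, 1, 1.5, 2, 3, 4, 6}, skips the FP8 rounding of the block scales, and may differ from TensorRT at rounding tie points.

import torch

FP4_E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def nvfp4_fake_quant(x: torch.Tensor, global_scale: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    assert x.shape[-1] % block_size == 0, "last dimension must be a multiple of block_size"
    blocks = x.reshape(*x.shape[:-1], -1, block_size)
    # Per-block scale derived from the runtime activation, expressed relative to the global scale.
    block_scale = blocks.abs().amax(dim=-1, keepdim=True).float() / 6.0 / global_scale
    block_scale = torch.where(block_scale == 0, torch.ones_like(block_scale), block_scale)
    scaled = blocks / (block_scale * global_scale)
    # Snap magnitudes to the nearest representable FP4 value, keep the sign, then dequantize.
    values = FP4_E2M1_VALUES.to(scaled.device)
    idx = (scaled.abs().unsqueeze(-1) - values).abs().argmin(dim=-1)
    deq = values[idx] * scaled.sign() * block_scale * global_scale
    return deq.reshape(x.shape).to(x.dtype)

# Usage: global scale follows the amax / (6 * 448) convention referenced in this series.
x = torch.randn(128, 64)
out = nvfp4_fake_quant(x, x.abs().amax().float() / (6 * 448))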
100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ requires = [ "cffi>=1.15.1", "typing-extensions>=4.7.0", "future>=0.18.3", - "tensorrt-cu12>=10.9.0,<10.10.0", + "tensorrt-cu12>=10.10.0,<10.11.0", "torch>=2.8.0.dev,<2.9.0", "pybind11==2.6.2", "numpy", @@ -56,10 +56,10 @@ keywords = [ ] dependencies = [ "torch>=2.8.0.dev,<2.9.0", - "tensorrt>=10.9.0,<10.10.0", - "tensorrt-cu12>=10.9.0,<10.10.0", - "tensorrt-cu12-bindings>=10.9.0,<10.10.0", - "tensorrt-cu12-libs>=10.9.0,<10.10.0", + "tensorrt>=10.10.0,<10.11.0", + "tensorrt-cu12>=10.10.0,<10.11.0", + "tensorrt-cu12-bindings>=10.10.0,<10.11.0", + "tensorrt-cu12-libs>=10.10.0,<10.11.0", "packaging>=23", "numpy", "typing-extensions>=4.7.0", diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 88b51bb1b3..862f70a719 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -215,7 +215,9 @@ def test_base_fp4(ir): class SimpleNetwork(torch.nn.Module): def __init__(self): super(SimpleNetwork, self).__init__() - self.linear1 = torch.nn.Linear(in_features=64, out_features=32, bias=False, dtype=torch.float16) + self.linear1 = torch.nn.Linear( + in_features=64, out_features=32, bias=True, dtype=torch.float16 + ) def forward(self, x): x = self.linear1(x) @@ -226,7 +228,7 @@ def calibrate_loop(model): model(input_tensor) input_tensor = torch.randn(128, 64, dtype=torch.float16).cuda() - breakpoint() + print(f"lan added amax: {input_tensor.abs().amax()}") model = SimpleNetwork().eval().cuda() From 58d158c5f40b1636e7175daf3fdba27796eb93d1 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sun, 18 May 2025 10:33:15 -0700 Subject: [PATCH 22/30] use strongly typed network --- py/torch_tensorrt/dynamo/_compiler.py | 14 ++++----- .../dynamo/conversion/_TRTInterpreter.py | 30 +++++++++---------- .../dynamo/conversion/impl/nvfp4_quantize.py | 17 ++++++----- tests/py/dynamo/models/test_models_export.py | 1 + 4 files changed, 32 insertions(+), 30 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index acd16a32f0..5e6f4c20c4 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -575,13 +575,13 @@ def compile( "\nThis feature is unimplemented in Torch-TRT Dynamo currently." 
) - if use_explicit_typing: - if len(enabled_precisions) != 1 or not any( - x in enabled_precisions for x in {torch.float32, dtype.f32} - ): - raise AssertionError( - f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}" - ) + # if use_explicit_typing: + # if len(enabled_precisions) != 1 or not any( + # x in enabled_precisions for x in {torch.float32, dtype.f32} + # ): + # raise AssertionError( + # f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}" + # ) if use_fp32_acc: logger.debug( diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index c51e56fe9b..cc63f7d271 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -274,25 +274,25 @@ def _populate_trt_builder_config( self.compilation_settings.dla_global_dram_size, ) - if dtype.float16 in self.compilation_settings.enabled_precisions: - builder_config.set_flag(trt.BuilderFlag.FP16) + # if dtype.float16 in self.compilation_settings.enabled_precisions: + # builder_config.set_flag(trt.BuilderFlag.FP16) - if dtype.int8 in self.compilation_settings.enabled_precisions: - builder_config.set_flag(trt.BuilderFlag.INT8) + # if dtype.int8 in self.compilation_settings.enabled_precisions: + # builder_config.set_flag(trt.BuilderFlag.INT8) - if dtype.fp8 in self.compilation_settings.enabled_precisions: - builder_config.set_flag(trt.BuilderFlag.FP8) + # if dtype.fp8 in self.compilation_settings.enabled_precisions: + # builder_config.set_flag(trt.BuilderFlag.FP8) - if dtype.fp4 in self.compilation_settings.enabled_precisions: - builder_config.set_flag(trt.BuilderFlag.FP4) - if dtype.bfloat16 in self.compilation_settings.enabled_precisions: - builder_config.set_flag(trt.BuilderFlag.BF16) + # if dtype.fp4 in self.compilation_settings.enabled_precisions: + # builder_config.set_flag(trt.BuilderFlag.FP4) + # if dtype.bfloat16 in self.compilation_settings.enabled_precisions: + # builder_config.set_flag(trt.BuilderFlag.BF16) - if self.compilation_settings.sparse_weights: - builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS) + # if self.compilation_settings.sparse_weights: + # builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS) - if self.compilation_settings.disable_tf32: - builder_config.clear_flag(trt.BuilderFlag.TF32) + # if self.compilation_settings.disable_tf32: + # builder_config.clear_flag(trt.BuilderFlag.TF32) if self.compilation_settings.immutable_weights: # non-refittable engine @@ -353,7 +353,7 @@ def _populate_trt_builder_config( builder_config.l2_limit_for_tiling = ( self.compilation_settings.l2_limit_for_tiling ) - + print(f"lan added builder_config:{builder_config=}") return builder_config def _create_timing_cache( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py index a610228379..caf984897a 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py @@ -133,7 +133,8 @@ def _dynamic_double_quantize( quantized_scale_in_fp8, global_scale, input_tensor.dtype ) dequantize_scale_layer.axis = axis - dequantize_scale_layer.set_output_type(0, input_tensor.dtype) + dequantize_scale_layer.to_type = input_tensor.dtype + # dequantize_scale_layer.set_output_type(0, input_tensor.dtype) set_layer_name( 
dequantize_scale_layer, target, name + "_dequantize_scale", source_ir ) @@ -144,7 +145,7 @@ def _dynamic_double_quantize( quantized_data_in_fp4, dequantized_scale, input_tensor.dtype ) dequantize_data_layer.axis = axis - dequantize_scale_layer.set_output_type(0, input_tensor.dtype) + dequantize_data_layer.to_type = input_tensor.dtype set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) dequantized_data = dequantize_data_layer.get_output(0) return dequantized_data @@ -235,8 +236,8 @@ def _static_double_quantize( original_dtype, ) dequantize_block_scale_layer.axis = axis - dequantize_block_scale_layer.precision = trt.DataType.FP8 - dequantize_block_scale_layer.set_output_type(0, original_dtype) + # dequantize_block_scale_layer.precision = trt.DataType.FP8 + dequantize_block_scale_layer.to_type = original_dtype set_layer_name( dequantize_block_scale_layer, target, @@ -244,7 +245,7 @@ def _static_double_quantize( source_ir, ) print( - f"lan added dequantize_block_scale_layer: {dequantize_block_scale_layer.axis=} {dequantize_block_scale_layer.precision=} {dequantize_block_scale_layer.get_output_type(0)=}" + f"lan added dequantize_block_scale_layer: {dequantize_block_scale_layer.to_type=} {dequantize_block_scale_layer.axis=} {dequantize_block_scale_layer.precision=} {dequantize_block_scale_layer.get_output_type(0)=}" ) dequantized_block_scale = dequantize_block_scale_layer.get_output(0) @@ -255,11 +256,11 @@ def _static_double_quantize( original_dtype, ) dequantize_data_layer.axis = axis - dequantize_data_layer.precision = trt.DataType.FP4 - dequantize_data_layer.set_output_type(0, original_dtype) + # dequantize_data_layer.precision = trt.DataType.FP4 + dequantize_data_layer.to_type = original_dtype set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) print( - f"lan added dequantize_data_layer: {dequantize_data_layer.axis=} {dequantize_data_layer.precision=} {dequantize_data_layer.get_output_type(0)=}" + f"lan added dequantize_data_layer: {dequantize_data_layer.to_type=} {dequantize_data_layer.axis=} {dequantize_data_layer.precision=} {dequantize_data_layer.get_output_type(0)=}" ) dequantized_data = dequantize_data_layer.get_output(0) return dequantized_data diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 862f70a719..0e3f86b994 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -260,6 +260,7 @@ def calibrate_loop(model): debug=True, cache_built_engines=False, reuse_cached_engines=False, + use_explicit_typing=True, ) outputs_trt = trt_model(input_tensor) From 5a622aeabcc53db763f96e900aa7253c6e8f3d72 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sun, 18 May 2025 10:53:54 -0700 Subject: [PATCH 23/30] test --- tests/py/dynamo/models/test_models_export.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 0e3f86b994..7dc3925d98 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -231,11 +231,13 @@ def calibrate_loop(model): print(f"lan added amax: {input_tensor.abs().amax()}") model = SimpleNetwork().eval().cuda() + output_pyt = model(input_tensor) + print(f"lan added pytorch output_pyt: {output_pyt}") quant_cfg = mtq.NVFP4_DEFAULT_CFG mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) # model has qdq nodes at this point - 
output_pyt = model(input_tensor) + torch.onnx.export(model, input_tensor, "mtq_model.onnx") with torch.no_grad(): @@ -265,7 +267,6 @@ def calibrate_loop(model): outputs_trt = trt_model(input_tensor) print(f"lan added torch_tensorrt outputs_trt: {outputs_trt}") - print(f"lan added pytorch output_pyt: {output_pyt}") abs_diff = torch.abs(output_pyt - outputs_trt) print(f"lan added max abs_diff: {abs_diff.max().item()}") assert torch.allclose(output_pyt, outputs_trt, rtol=4e-1, atol=4e-1) From 8880b14e01d835f59d8ed5743ae66b49dc8e716f Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sun, 18 May 2025 14:18:23 -0700 Subject: [PATCH 24/30] print out internal weight scaling value --- .../dynamo/conversion/impl/nvfp4_quantize.py | 6 +++--- tests/py/dynamo/models/test_models_export.py | 9 +++++++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py index caf984897a..7495a5bead 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py @@ -211,20 +211,20 @@ def _static_double_quantize( raise ValueError( f"Currently try float16, float32 only on weights tensor. Unsupported dtype: {weights_tensor.dtype}" ) - block_scale_fp8 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( weights_tensor, 16, global_scale, )[0] - + print(f"lan added global_scale: {global_scale.shape=} {global_scale.dtype=} {global_scale=}") + print(f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}") weights_tensor_fp4 = nvfp4_tensor.NVFP4QTensor.quantize( weights_tensor, 16, block_scale_fp8, global_scale, )[0]._quantized_data - + print(f"lan added weights_tensor_fp4: {weights_tensor_fp4.shape=} {weights_tensor_fp4.dtype=} {weights_tensor_fp4=}") block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8") global_scale = to_torch(global_scale, None) global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 7dc3925d98..c42e30a460 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -227,12 +227,17 @@ def calibrate_loop(model): """Simple calibration function for testing.""" model(input_tensor) - input_tensor = torch.randn(128, 64, dtype=torch.float16).cuda() + input_tensor = torch.ones(128, 64, dtype=torch.float16).cuda() print(f"lan added amax: {input_tensor.abs().amax()}") model = SimpleNetwork().eval().cuda() + model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=torch.float16).cuda()) + model.linear1.bias = torch.nn.Parameter(torch.zeros(128, 32, dtype=torch.float16).cuda()) output_pyt = model(input_tensor) - print(f"lan added pytorch output_pyt: {output_pyt}") + print(f"lan added model input: {input_tensor=}") + print(f"lan added model weight: {model.linear1.weight=}") + print(f"lan added model bias: {model.linear1.bias=}") + print(f"lan added pytorch output_pyt: {output_pyt} {output_pyt.dtype=} {output_pyt.shape=}") quant_cfg = mtq.NVFP4_DEFAULT_CFG mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) From bae918810fe5070ee1e22aa32fb6799e9855c8a1 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 21 May 2025 14:04:47 -0700 Subject: [PATCH 25/30] add disable flag --- py/torch_tensorrt/dynamo/_compiler.py | 12 ++++++------ 
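The weights_tensor_fp4 printed above is the packed uint8 representation: two 4-bit code points per byte, which is also why the FP4 constant path earlier in this series uses count = numel * 2 and doubles the last dimension of the shape. A small self-contained sketch of that packing convention follows; the actual FP4 code values come from modelopt's _cast_fp4 and are not reproduced here.

import torch

def pack_fp4_codes(codes: torch.Tensor) -> torch.Tensor:
    # codes: uint8 tensor of 4-bit code points (0..15); last dimension must be even.
    assert codes.dtype == torch.uint8 and codes.shape[-1] % 2 == 0
    return (codes[..., 1::2] << 4) | codes[..., 0::2]

def unpack_fp4_codes(packed: torch.Tensor) -> torch.Tensor:
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    return torch.stack((lo, hi), dim=-1).reshape(*packed.shape[:-1], -1)

codes = torch.randint(0, 16, (32, 64), dtype=torch.uint8)
packed = pack_fp4_codes(codes)          # shape (32, 32): half the logical last dimension
assert torch.equal(unpack_fp4_codes(packed), codes)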
.../dynamo/conversion/_TRTInterpreter.py | 13 +++++++------ .../dynamo/conversion/impl/nvfp4_quantize.py | 11 ++++++++--- tests/py/dynamo/models/test_models_export.py | 5 ++--- 4 files changed, 23 insertions(+), 18 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 4db11daa78..c4bc7238c9 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -582,12 +582,12 @@ def compile( ) # if use_explicit_typing: - # if len(enabled_precisions) != 1 or not any( - # x in enabled_precisions for x in {torch.float32, dtype.f32} - # ): - # raise AssertionError( - # f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}" - # ) + # if len(enabled_precisions) != 1 or not any( + # x in enabled_precisions for x in {torch.float32, dtype.f32} + # ): + # raise AssertionError( + # f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}" + # ) if use_fp32_acc: logger.debug( diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 8a9994e901..a9e7d194e1 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -87,6 +87,7 @@ def __init__( flag = 0 if compilation_settings.use_explicit_typing: + _LOGGER.info("Using strongly typed network definition") STRONGLY_TYPED = 1 << (int)( trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED ) @@ -277,8 +278,8 @@ def _populate_trt_builder_config( # if dtype.float16 in self.compilation_settings.enabled_precisions: # builder_config.set_flag(trt.BuilderFlag.FP16) - # if dtype.int8 in self.compilation_settings.enabled_precisions: - # builder_config.set_flag(trt.BuilderFlag.INT8) + if dtype.int8 in self.compilation_settings.enabled_precisions: + builder_config.set_flag(trt.BuilderFlag.INT8) # if dtype.fp8 in self.compilation_settings.enabled_precisions: # builder_config.set_flag(trt.BuilderFlag.FP8) @@ -288,11 +289,11 @@ def _populate_trt_builder_config( # if dtype.bfloat16 in self.compilation_settings.enabled_precisions: # builder_config.set_flag(trt.BuilderFlag.BF16) - # if self.compilation_settings.sparse_weights: - # builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS) + if self.compilation_settings.sparse_weights: + builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS) - # if self.compilation_settings.disable_tf32: - # builder_config.clear_flag(trt.BuilderFlag.TF32) + if self.compilation_settings.disable_tf32: + builder_config.clear_flag(trt.BuilderFlag.TF32) if self.compilation_settings.immutable_weights: # non-refittable engine diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py index 7495a5bead..2f052969f7 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py @@ -13,7 +13,7 @@ ) from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTTensor - +import os def nvfp4_quantize( ctx: ConversionContext, @@ -42,7 +42,7 @@ def nvfp4_quantize( with unset_fake_temporarily(): axis = len(input_tensor.shape) - 1 global_scale = _calculate_global_scale(ctx, name, amax) - print(f"lan added global_scale: {input_tensor.shape=} {input_tensor.dtype=}") + print(f"lan added input_tensor: 
{input_tensor.shape=} {input_tensor.dtype=}") print(f"lan added global_scale: {global_scale.shape=} {global_scale.dtype=}") if ".weight_quantizer" in name: # _test_weights_scaling_factor(input_tensor, global_scale) @@ -107,6 +107,9 @@ def _dynamic_double_quantize( The data type for block scale. Default is FP8. """ + if os.getenv("DISABLE_DYNAMIC_QUANTIZE", "false").lower() == "true": + print("lan added disable_dynamic_quantize is set, skipping dynamic quantize") + return input_tensor global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") if input_tensor.dtype not in [trt.DataType.HALF, trt.DataType.FLOAT]: @@ -200,7 +203,9 @@ def _static_double_quantize( Returns: quantized data tensor in fp4 """ - + if os.getenv("DISABLE_STATIC_QUANTIZE", "false").lower() == "true": + print("lan added disable_static_quantize is set, skipping static quantize") + return get_trt_tensor(ctx, weights_tensor, name + "_weights") import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor if weights_tensor.dtype == torch.float16: diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index c42e30a460..f845351977 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -3,7 +3,6 @@ import platform import unittest from importlib import metadata - import pytest import timm import torch @@ -232,7 +231,7 @@ def calibrate_loop(model): print(f"lan added amax: {input_tensor.abs().amax()}") model = SimpleNetwork().eval().cuda() model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=torch.float16).cuda()) - model.linear1.bias = torch.nn.Parameter(torch.zeros(128, 32, dtype=torch.float16).cuda()) + model.linear1.bias = torch.nn.Parameter(torch.ones(128, 32, dtype=torch.float16).cuda()) output_pyt = model(input_tensor) print(f"lan added model input: {input_tensor=}") print(f"lan added model weight: {model.linear1.weight=}") @@ -274,7 +273,7 @@ def calibrate_loop(model): print(f"lan added torch_tensorrt outputs_trt: {outputs_trt}") abs_diff = torch.abs(output_pyt - outputs_trt) print(f"lan added max abs_diff: {abs_diff.max().item()}") - assert torch.allclose(output_pyt, outputs_trt, rtol=4e-1, atol=4e-1) + assert torch.allclose(output_pyt, outputs_trt, rtol=0.8, atol=0.8) @unittest.skipIf( From 7b3cd74bb2c0d8dc60927c8a1d0a2738f328975a Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Fri, 23 May 2025 09:51:22 -0700 Subject: [PATCH 26/30] add transpose --- .../dynamo/conversion/impl/nvfp4_quantize.py | 81 +++++++++++++++++-- tests/py/dynamo/models/test_models_export.py | 11 +-- 2 files changed, 75 insertions(+), 17 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py index 2f052969f7..91b61435e5 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py @@ -14,6 +14,7 @@ from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTTensor import os +import torch_tensorrt.dynamo.conversion.impl as impl def nvfp4_quantize( ctx: ConversionContext, @@ -206,6 +207,10 @@ def _static_double_quantize( if os.getenv("DISABLE_STATIC_QUANTIZE", "false").lower() == "true": print("lan added disable_static_quantize is set, skipping static quantize") return get_trt_tensor(ctx, weights_tensor, name + "_weights") + if os.getenv("ENABLE_TRANSPOSE", "false").lower() == "true": + 
enable_transpose = True + else: + enable_transpose = False import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor if weights_tensor.dtype == torch.float16: @@ -216,24 +221,34 @@ def _static_double_quantize( raise ValueError( f"Currently try float16, float32 only on weights tensor. Unsupported dtype: {weights_tensor.dtype}" ) - block_scale_fp8 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( + block_scale = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor( weights_tensor, 16, global_scale, + keep_high_precision=True, )[0] - print(f"lan added global_scale: {global_scale.shape=} {global_scale.dtype=} {global_scale=}") - print(f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}") - weights_tensor_fp4 = nvfp4_tensor.NVFP4QTensor.quantize( + weights_tensor_scaled = nvfp4_tensor.NVFP4QTensor.quantize( weights_tensor, 16, - block_scale_fp8, + block_scale, global_scale, - )[0]._quantized_data - print(f"lan added weights_tensor_fp4: {weights_tensor_fp4.shape=} {weights_tensor_fp4.dtype=} {weights_tensor_fp4=}") + keep_high_precision=True, + ) + if enable_transpose: + block_scale = block_scale.transpose(0, 1) + weights_tensor_scaled = weights_tensor_scaled.transpose(0, 1) + + block_scale_fp8 = block_scale.to(torch.float8_e4m3fn) + weights_tensor_uint4 = nvfp4_tensor.NVFP4QTensor._cast_fp4(weights_tensor_scaled) + weights_tensor_uint8 = (weights_tensor_uint4[..., 1::2] << 4) | weights_tensor_uint4[..., 0::2] + + print(f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}") + print(f"lan added weights_tensor_uint8: {weights_tensor_uint8.shape=} {weights_tensor_uint8.dtype=}") + block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8") global_scale = to_torch(global_scale, None) global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") - weights_tensor_fp4 = get_trt_tensor(ctx, weights_tensor_fp4, name + "_weights_fp4") + weights_tensor_fp4 = get_trt_tensor(ctx, weights_tensor_uint8, name + "_weights_fp4") # dequantize block scale from fp8 to original dtype (default is float32) dequantize_block_scale_layer = ctx.net.add_dequantize( block_scale_fp8, @@ -268,6 +283,8 @@ def _static_double_quantize( f"lan added dequantize_data_layer: {dequantize_data_layer.to_type=} {dequantize_data_layer.axis=} {dequantize_data_layer.precision=} {dequantize_data_layer.get_output_type(0)=}" ) dequantized_data = dequantize_data_layer.get_output(0) + if enable_transpose: + dequantized_data = impl.permutation.permute(ctx, target, source_ir, name + "_dequantized_data_transposed", dequantized_data, (-1, -2)) return dequantized_data @@ -286,3 +303,51 @@ def _calculate_global_scale( if global_scale == 0: global_scale = 1.0 return global_scale + +def _get_weights_scaling_factor_transposed( + weights_tensor: torch.Tensor, + global_scale: torch.Tensor, + keep_high_precision: bool = False, +) -> torch.Tensor: + [k, n] = weights_tensor.shape[-2:] + assert k % 16 == 0, "Weight shape is not divisible for block size for block quantiation." 
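The helper being added here mirrors modelopt's two-level NVFP4 scaling. As a standalone reference, a sketch of the arithmetic, assuming block size 16 and the amax / 6 / 448 global-scale convention used elsewhere in this series (6 is the FP4 E2M1 maximum, 448 the FP8 E4M3 maximum):

import torch

def nvfp4_scales(weights: torch.Tensor, block_size: int = 16):
    # Global per-tensor scale: amax / (6 * 448).
    amax = weights.abs().amax().float()
    global_scale = amax / (6.0 * 448.0)
    if global_scale.item() == 0.0:
        global_scale = torch.tensor(1.0)
    # Per-block scale along the last axis, divided by the global scale so it fits in FP8 (E4M3).
    k = weights.shape[-1]
    assert k % block_size == 0, "last dimension must be a multiple of block_size"
    blocks = weights.reshape(*weights.shape[:-1], k // block_size, block_size)
    per_block_scale = blocks.abs().amax(dim=-1).float() / 6.0
    q_per_block_scale = per_block_scale / global_scale
    q_per_block_scale[per_block_scale == 0] = 1.0
    return global_scale, q_per_block_scale.to(torch.float8_e4m3fn)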
+ weights_tensor = weights_tensor.reshape(tuple(weights_tensor.shape[:-2]) + (k // 16, n, 16)) + per_block_amax = weights_tensor.abs().amax(dim=-1).float() + per_block_scale = per_block_amax / 6.0 + q_per_block_scale = per_block_scale / global_scale + q_per_block_scale[per_block_scale == 0] = 1.0 + if not keep_high_precision: + q_per_block_scale = q_per_block_scale.to(torch.float8_e4m3fn) + return q_per_block_scale + +def _quantized_weights_transposed( + input: torch.Tensor, + weights_scaling_factor: torch.Tensor, + weights_scaling_factor_2: torch.Tensor, + keep_high_precision: bool = False, +) -> torch.Tensor: + + # Reshape the weight and scale factors + input = input.view((*tuple(input.shape[:-1]), -1, block_size)) + + # Scale weights + scaled_weight = input / ( + (weights_scaling_factor.to(torch.float32) * weights_scaling_factor_2).unsqueeze(-1) + ) + + # Reshape weights to original + scaled_weight = scaled_weight.view((*tuple(scaled_weight.shape[:-2]), -1)) + + if keep_high_precision: + return scaled_weight + # Cast weights to fp4 + q_weight = cls._cast_fp4(scaled_weight) + # Pack weights + packed_weight = (q_weight[..., 1::2] << 4) | q_weight[..., 0::2] + return ( + cls(input_shape, input_dtype, packed_weight), + weights_scaling_factor, + weights_scaling_factor_2, + ) + + \ No newline at end of file diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index f845351977..b661c5fbf8 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -230,8 +230,8 @@ def calibrate_loop(model): print(f"lan added amax: {input_tensor.abs().amax()}") model = SimpleNetwork().eval().cuda() - model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=torch.float16).cuda()) - model.linear1.bias = torch.nn.Parameter(torch.ones(128, 32, dtype=torch.float16).cuda()) + #model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=torch.float16).cuda()) + #model.linear1.bias = torch.nn.Parameter(torch.ones(128, 32, dtype=torch.float16).cuda()) output_pyt = model(input_tensor) print(f"lan added model input: {input_tensor=}") print(f"lan added model weight: {model.linear1.weight=}") @@ -241,18 +241,11 @@ def calibrate_loop(model): quant_cfg = mtq.NVFP4_DEFAULT_CFG mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) # model has qdq nodes at this point - - torch.onnx.export(model, input_tensor, "mtq_model.onnx") - with torch.no_grad(): with export_torch_mode(): exp_program = torch.export.export(model, (input_tensor,), strict=False) from torch.fx import passes - g = passes.graph_drawer.FxGraphDrawer(exp_program, "torch_export_fp4") - with open("a.svg", "wb") as f: - f.write(g.get_dot_graph().create_svg()) - trt_model = torchtrt.dynamo.compile( exp_program, inputs=[input_tensor], From 1e2fccb74cf0c61e3cdda3be3325685604082630 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Fri, 23 May 2025 09:59:22 -0700 Subject: [PATCH 27/30] try different axis --- py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py index 91b61435e5..9221401a78 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py @@ -41,7 +41,7 @@ def nvfp4_quantize( f"nvfp4_quantize converter received an input of {input_tensor.shape} shape. 
Supported shapes: 2D or 3D" ) with unset_fake_temporarily(): - axis = len(input_tensor.shape) - 1 + axis = -1 global_scale = _calculate_global_scale(ctx, name, amax) print(f"lan added input_tensor: {input_tensor.shape=} {input_tensor.dtype=}") print(f"lan added global_scale: {global_scale.shape=} {global_scale.dtype=}") @@ -209,6 +209,7 @@ def _static_double_quantize( return get_trt_tensor(ctx, weights_tensor, name + "_weights") if os.getenv("ENABLE_TRANSPOSE", "false").lower() == "true": enable_transpose = True + axis = -2 else: enable_transpose = False import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor @@ -256,7 +257,6 @@ def _static_double_quantize( original_dtype, ) dequantize_block_scale_layer.axis = axis - # dequantize_block_scale_layer.precision = trt.DataType.FP8 dequantize_block_scale_layer.to_type = original_dtype set_layer_name( dequantize_block_scale_layer, @@ -276,7 +276,6 @@ def _static_double_quantize( original_dtype, ) dequantize_data_layer.axis = axis - # dequantize_data_layer.precision = trt.DataType.FP4 dequantize_data_layer.to_type = original_dtype set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) print( From c5e3498bd43f48dabb7033d8cd76220ed11bea51 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Fri, 23 May 2025 14:28:48 -0700 Subject: [PATCH 28/30] add disable gemm option --- .../dynamo/conversion/impl/addmm.py | 6 ++- .../dynamo/conversion/impl/nvfp4_quantize.py | 39 +++---------------- .../dynamo/conversion/impl/permutation.py | 6 ++- tests/py/dynamo/models/test_models_export.py | 30 +++++++++----- 4 files changed, 35 insertions(+), 46 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/addmm.py b/py/torch_tensorrt/dynamo/conversion/impl/addmm.py index 1a0690852a..73d98acfdf 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/addmm.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/addmm.py @@ -7,7 +7,7 @@ from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.fx.types import TRTTensor - +import os def addmm( ctx: ConversionContext, @@ -21,6 +21,10 @@ def addmm( beta: Union[float, int], alpha: Union[float, int], ) -> TRTTensor: + if os.getenv("DISABLE_GEMM", "false").lower() == "true": + print("lan added disable_gemm is set, skip addmm and returning mat2") + return mat2 + print("lan added disable_gemm is not set, doing addmm") mm = impl.matmul.matrix_multiply(ctx, target, source_ir, f"{name}_mm", mat1, mat2) if alpha != 1: mm = impl.elementwise.mul( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py index 9221401a78..a2a7a0e3cc 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py @@ -111,6 +111,7 @@ def _dynamic_double_quantize( if os.getenv("DISABLE_DYNAMIC_QUANTIZE", "false").lower() == "true": print("lan added disable_dynamic_quantize is set, skipping dynamic quantize") return input_tensor + print("lan added disable_dynamic_quantize is not set, doing dynamic quantize") global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale") if input_tensor.dtype not in [trt.DataType.HALF, trt.DataType.FLOAT]: @@ -207,7 +208,9 @@ def _static_double_quantize( if os.getenv("DISABLE_STATIC_QUANTIZE", "false").lower() == "true": print("lan added disable_static_quantize is set, skipping static quantize") return get_trt_tensor(ctx, 
weights_tensor, name + "_weights") + print("lan added static disable_static_quantize is not set, do disable_static_quantize ") if os.getenv("ENABLE_TRANSPOSE", "false").lower() == "true": + print("lan added enable_transpose is set, transposing weights tensor") enable_transpose = True axis = -2 else: @@ -265,7 +268,7 @@ def _static_double_quantize( source_ir, ) print( - f"lan added dequantize_block_scale_layer: {dequantize_block_scale_layer.to_type=} {dequantize_block_scale_layer.axis=} {dequantize_block_scale_layer.precision=} {dequantize_block_scale_layer.get_output_type(0)=}" + f"lan added dequantize_block_scale_layer: {dequantize_block_scale_layer.to_type=} {dequantize_block_scale_layer.axis=} {dequantize_block_scale_layer.get_input(0).shape=} {dequantize_block_scale_layer.get_input(1).shape=}" ) dequantized_block_scale = dequantize_block_scale_layer.get_output(0) @@ -279,7 +282,7 @@ def _static_double_quantize( dequantize_data_layer.to_type = original_dtype set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir) print( - f"lan added dequantize_data_layer: {dequantize_data_layer.to_type=} {dequantize_data_layer.axis=} {dequantize_data_layer.precision=} {dequantize_data_layer.get_output_type(0)=}" + f"lan added dequantize_data_layer: {dequantize_data_layer.to_type=} {dequantize_data_layer.axis=} {dequantize_data_layer.get_input(0).shape=} {dequantize_data_layer.get_input(1).shape=}" ) dequantized_data = dequantize_data_layer.get_output(0) if enable_transpose: @@ -318,35 +321,3 @@ def _get_weights_scaling_factor_transposed( if not keep_high_precision: q_per_block_scale = q_per_block_scale.to(torch.float8_e4m3fn) return q_per_block_scale - -def _quantized_weights_transposed( - input: torch.Tensor, - weights_scaling_factor: torch.Tensor, - weights_scaling_factor_2: torch.Tensor, - keep_high_precision: bool = False, -) -> torch.Tensor: - - # Reshape the weight and scale factors - input = input.view((*tuple(input.shape[:-1]), -1, block_size)) - - # Scale weights - scaled_weight = input / ( - (weights_scaling_factor.to(torch.float32) * weights_scaling_factor_2).unsqueeze(-1) - ) - - # Reshape weights to original - scaled_weight = scaled_weight.view((*tuple(scaled_weight.shape[:-2]), -1)) - - if keep_high_precision: - return scaled_weight - # Cast weights to fp4 - q_weight = cls._cast_fp4(scaled_weight) - # Pack weights - packed_weight = (q_weight[..., 1::2] << 4) | q_weight[..., 0::2] - return ( - cls(input_shape, input_dtype, packed_weight), - weights_scaling_factor, - weights_scaling_factor_2, - ) - - \ No newline at end of file diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py index 1537d0fdbe..bc444f4e10 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -13,7 +13,7 @@ ) from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape from torch_tensorrt.fx.types import TRTTensor - +import os def permute( ctx: ConversionContext, @@ -23,6 +23,10 @@ def permute( input: TRTTensor, permutation: Sequence[int], ) -> TRTTensor: + if os.getenv("DISABLE_GEMM", "false").lower() == "true": + print("lan added disable_gemm is set, skip permute and returning mat2") + return mat2 + print("lan added disable_gemm is not set, doing permute") if not isinstance(input, TRTTensor): raise RuntimeError( f"permute received input {input} that is not a TensorRT ITensor" diff --git 
a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index b661c5fbf8..08cb092317 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -13,7 +13,7 @@ from packaging.version import Version assertions = unittest.TestCase() - +import os @pytest.mark.unit def test_resnet18(ir): @@ -228,16 +228,17 @@ def calibrate_loop(model): input_tensor = torch.ones(128, 64, dtype=torch.float16).cuda() - print(f"lan added amax: {input_tensor.abs().amax()}") + model = SimpleNetwork().eval().cuda() - #model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=torch.float16).cuda()) + model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=torch.float16).cuda()) #model.linear1.bias = torch.nn.Parameter(torch.ones(128, 32, dtype=torch.float16).cuda()) - output_pyt = model(input_tensor) + print(f"lan added amax: {input_tensor.abs().amax()=}") + print(f"lan added amax: {model.linear1.weight.abs().amax()=}") + expected_output = model(input_tensor) print(f"lan added model input: {input_tensor=}") print(f"lan added model weight: {model.linear1.weight=}") print(f"lan added model bias: {model.linear1.bias=}") - print(f"lan added pytorch output_pyt: {output_pyt} {output_pyt.dtype=} {output_pyt.shape=}") - + quant_cfg = mtq.NVFP4_DEFAULT_CFG mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) # model has qdq nodes at this point @@ -263,10 +264,19 @@ def calibrate_loop(model): ) outputs_trt = trt_model(input_tensor) - print(f"lan added torch_tensorrt outputs_trt: {outputs_trt}") - abs_diff = torch.abs(output_pyt - outputs_trt) - print(f"lan added max abs_diff: {abs_diff.max().item()}") - assert torch.allclose(output_pyt, outputs_trt, rtol=0.8, atol=0.8) + if os.getenv("DISABLE_GEMM", "false").lower() == "true": + print("lan added disable_gemm is set, compring result with weights") + expected_output = model.linear1.weight + else: + print("lan added disable_gemm is not set, compring result with pytorch") + + print(f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=}") + print(f"lan added pytorch output_pyt: {expected_output=} {outexpected_outputput_pyt.dtype=} {expected_output.shape=}") + + abs_diff = torch.abs(expected_output - outputs_trt) + print(f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}") + print(f"lan added abs_diff: {abs_diff=}") + assert torch.allclose(expected_output, outputs_trt, rtol=0.8, atol=0.8) @unittest.skipIf( From 925a68d3df376291ecb8a8f2ee27b7195e32305e Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Fri, 23 May 2025 14:40:35 -0700 Subject: [PATCH 29/30] test --- tests/py/dynamo/models/test_models_export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 08cb092317..c529c0b0dd 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -271,7 +271,7 @@ def calibrate_loop(model): print("lan added disable_gemm is not set, compring result with pytorch") print(f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=}") - print(f"lan added pytorch output_pyt: {expected_output=} {outexpected_outputput_pyt.dtype=} {expected_output.shape=}") + print(f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=}") abs_diff = torch.abs(expected_output - outputs_trt) 
print(f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}") From f526c0f2b2e11208e3897dc201576760c7023fda Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Fri, 23 May 2025 16:23:36 -0700 Subject: [PATCH 30/30] add float32 support --- py/torch_tensorrt/dynamo/_compiler.py | 12 ++++++------ .../dynamo/conversion/_TRTInterpreter.py | 17 +++++++---------- .../dynamo/conversion/impl/permutation.py | 4 ++-- tests/py/dynamo/models/test_models_export.py | 15 ++++++++------- 4 files changed, 23 insertions(+), 25 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index c4bc7238c9..4db11daa78 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -582,12 +582,12 @@ def compile( ) # if use_explicit_typing: - # if len(enabled_precisions) != 1 or not any( - # x in enabled_precisions for x in {torch.float32, dtype.f32} - # ): - # raise AssertionError( - # f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}" - # ) + # if len(enabled_precisions) != 1 or not any( + # x in enabled_precisions for x in {torch.float32, dtype.f32} + # ): + # raise AssertionError( + # f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}" + # ) if use_fp32_acc: logger.debug( diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index a9e7d194e1..ecf08f38c4 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -87,7 +87,6 @@ def __init__( flag = 0 if compilation_settings.use_explicit_typing: - _LOGGER.info("Using strongly typed network definition") STRONGLY_TYPED = 1 << (int)( trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED ) @@ -275,19 +274,17 @@ def _populate_trt_builder_config( self.compilation_settings.dla_global_dram_size, ) - # if dtype.float16 in self.compilation_settings.enabled_precisions: - # builder_config.set_flag(trt.BuilderFlag.FP16) + if not self.compilation_settings.use_explicit_typing and dtype.float16 in self.compilation_settings.enabled_precisions: + builder_config.set_flag(trt.BuilderFlag.FP16) if dtype.int8 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.INT8) - # if dtype.fp8 in self.compilation_settings.enabled_precisions: - # builder_config.set_flag(trt.BuilderFlag.FP8) + if not self.compilation_settings.use_explicit_typing and dtype.fp8 in self.compilation_settings.enabled_precisions: + builder_config.set_flag(trt.BuilderFlag.FP8) - # if dtype.fp4 in self.compilation_settings.enabled_precisions: - # builder_config.set_flag(trt.BuilderFlag.FP4) - # if dtype.bfloat16 in self.compilation_settings.enabled_precisions: - # builder_config.set_flag(trt.BuilderFlag.BF16) + if dtype.bfloat16 in self.compilation_settings.enabled_precisions: + builder_config.set_flag(trt.BuilderFlag.BF16) if self.compilation_settings.sparse_weights: builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS) @@ -354,7 +351,7 @@ def _populate_trt_builder_config( builder_config.l2_limit_for_tiling = ( self.compilation_settings.l2_limit_for_tiling ) - print(f"lan added builder_config:{builder_config=}") + return builder_config def _create_timing_cache( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py index 
bc444f4e10..4408b62809 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -24,8 +24,8 @@ def permute( permutation: Sequence[int], ) -> TRTTensor: if os.getenv("DISABLE_GEMM", "false").lower() == "true": - print("lan added disable_gemm is set, skip permute and returning mat2") - return mat2 + print("lan added disable_gemm is set, skip permute") + return input print("lan added disable_gemm is not set, doing permute") if not isinstance(input, TRTTensor): raise RuntimeError( diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index c529c0b0dd..175d3d79d7 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -210,12 +210,13 @@ def test_resnet18_half(ir): def test_base_fp4(ir): import modelopt.torch.quantization as mtq from modelopt.torch.quantization.utils import export_torch_mode + dtype = torch.float16 class SimpleNetwork(torch.nn.Module): def __init__(self): super(SimpleNetwork, self).__init__() self.linear1 = torch.nn.Linear( - in_features=64, out_features=32, bias=True, dtype=torch.float16 + in_features=64, out_features=32, bias=True, dtype=dtype ) def forward(self, x): @@ -226,12 +227,12 @@ def calibrate_loop(model): """Simple calibration function for testing.""" model(input_tensor) - input_tensor = torch.ones(128, 64, dtype=torch.float16).cuda() + input_tensor = torch.ones(128, 64, dtype=dtype).cuda() model = SimpleNetwork().eval().cuda() - model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=torch.float16).cuda()) - #model.linear1.bias = torch.nn.Parameter(torch.ones(128, 32, dtype=torch.float16).cuda()) + model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=dtype).cuda()) + model.linear1.bias = torch.nn.Parameter(torch.zeros(128, 32, dtype=dtype).cuda()) print(f"lan added amax: {input_tensor.abs().amax()=}") print(f"lan added amax: {model.linear1.weight.abs().amax()=}") expected_output = model(input_tensor) @@ -260,7 +261,7 @@ def calibrate_loop(model): debug=True, cache_built_engines=False, reuse_cached_engines=False, - use_explicit_typing=True, + use_explicit_typing=dtype == torch.float16, ) outputs_trt = trt_model(input_tensor) @@ -270,8 +271,8 @@ def calibrate_loop(model): else: print("lan added disable_gemm is not set, compring result with pytorch") - print(f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=}") - print(f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=}") + print(f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}") + print(f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}") abs_diff = torch.abs(expected_output - outputs_trt) print(f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}")
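
Note for reviewers following the converter changes above: the NVFP4 weight path stores two 4-bit E2M1 codes packed per uint8 byte, plus one FP8 (e4m3) scale per 16-element block and a single FP32 global scale; the two chained dequantize layers then reconstruct w ~= decode(code) * block_scale * global_scale. The sketch below is a minimal, self-contained illustration of that double-quantization scheme, not the modelopt/TensorRT implementation: quantize_to_e2m1_codes is a hypothetical stand-in for NVFP4QTensor._cast_fp4, and both the 4-bit code layout (sign in bit 3, magnitude index in bits 0-2) and the amax / (6 * 448) global-scale convention are assumptions made for illustration only.

import torch

# The eight non-negative magnitudes representable in FP4 E2M1.
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def quantize_to_e2m1_codes(x: torch.Tensor) -> torch.Tensor:
    # Nearest-value rounding to a 4-bit code: magnitude index in bits 0-2,
    # sign in bit 3 (assumed layout; stand-in for NVFP4QTensor._cast_fp4).
    idx = (x.abs().unsqueeze(-1) - E2M1_VALUES).abs().argmin(dim=-1).to(torch.uint8)
    sign = (x < 0).to(torch.uint8) << 3
    return sign | idx


def nvfp4_double_quantize_sketch(w: torch.Tensor, block: int = 16):
    # w: [n, k] weights; k must be divisible by the block size (16 for NVFP4).
    n, k = w.shape
    assert k % block == 0, "weight columns must be divisible by the block size"
    w = w.float()
    # Global scale (assumed convention): amax / (fp4_max * fp8_e4m3_max).
    global_scale = w.abs().amax() / (6.0 * 448.0)
    if global_scale == 0:
        global_scale = torch.tensor(1.0)
    blocks = w.reshape(n, k // block, block)
    # Per-block scale: block amax mapped to the FP4 maximum (6.0), then
    # expressed relative to the global scale and stored as FP8 e4m3.
    per_block_scale = blocks.abs().amax(dim=-1) / 6.0
    q_block_scale = per_block_scale / global_scale
    q_block_scale[per_block_scale == 0] = 1.0
    block_scale_fp8 = q_block_scale.to(torch.float8_e4m3fn)
    # Scale each block into the E2M1 range, take 4-bit codes, and pack two
    # codes per byte, mirroring the packing used in _static_double_quantize.
    scaled = blocks / (q_block_scale * global_scale).unsqueeze(-1)
    codes = quantize_to_e2m1_codes(scaled).reshape(n, k)
    packed = (codes[..., 1::2] << 4) | codes[..., 0::2]
    return packed, block_scale_fp8, global_scale


if __name__ == "__main__":
    w = torch.randn(32, 64, dtype=torch.float16)
    packed, block_scale, global_scale = nvfp4_double_quantize_sketch(w)
    # 64 columns -> 32 packed bytes per row and 4 FP8 block scales per row.
    print(packed.shape, packed.dtype, block_scale.shape, block_scale.dtype)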