From fa64d54ab513db68a3ab700ecea34eb3eb051322 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 23 Jul 2025 16:36:05 -0700 Subject: [PATCH 1/2] add more generic kernel for fp8 blockwise scaling stack-info: PR: https://github.com/pytorch/ao/pull/2592, branch: danielvegamyhre/stack/15 --- ...enchmark_blockwise_scaled_linear_triton.py | 2 +- .../bench_fp8_blockwise_quant_kernels.py | 169 +++++++++++++ benchmarks/float8/utils.py | 10 + test/prototype/test_blockwise_triton.py | 72 ------ test/prototype/test_fp8_blockwise_kernels.py | 157 ++++++++++++ torchao/prototype/blockwise_fp8/__init__.py | 2 +- .../blockwise_fp8/blockwise_linear.py | 2 +- .../{blockwise_quantization.py => kernels.py} | 230 +++++++++++++++++- 8 files changed, 568 insertions(+), 76 deletions(-) create mode 100644 benchmarks/float8/bench_fp8_blockwise_quant_kernels.py delete mode 100644 test/prototype/test_blockwise_triton.py create mode 100644 test/prototype/test_fp8_blockwise_kernels.py rename torchao/prototype/blockwise_fp8/{blockwise_quantization.py => kernels.py} (56%) diff --git a/benchmarks/benchmark_blockwise_scaled_linear_triton.py b/benchmarks/benchmark_blockwise_scaled_linear_triton.py index 809202170a..bdfae9d149 100644 --- a/benchmarks/benchmark_blockwise_scaled_linear_triton.py +++ b/benchmarks/benchmark_blockwise_scaled_linear_triton.py @@ -13,7 +13,7 @@ from triton.testing import do_bench from torchao.float8.float8_utils import compute_error - from torchao.prototype.blockwise_fp8.blockwise_quantization import ( + from torchao.prototype.blockwise_fp8.kernels import ( blockwise_fp8_gemm, fp8_blockwise_act_quant, fp8_blockwise_weight_quant, diff --git a/benchmarks/float8/bench_fp8_blockwise_quant_kernels.py b/benchmarks/float8/bench_fp8_blockwise_quant_kernels.py new file mode 100644 index 0000000000..1f65d98d3e --- /dev/null +++ b/benchmarks/float8/bench_fp8_blockwise_quant_kernels.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
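+#
+# Compares three fp8 blockwise quantization implementations on square bf16 inputs:
+# a torch.compile'd PyTorch reference (torch_us), the DeepSeek-style Triton kernels
+# in torchao (deepgemm_us), and fbgemm's triton_quantize_fp8_block (fbgemm_us).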
+# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py +import argparse +import itertools +from dataclasses import dataclass +from typing import List + +import torch +from tabulate import tabulate +from tqdm import tqdm +from utils import benchmark_microseconds + +from torchao.prototype.blockwise_fp8.kernels import ( + fp8_blockwise_act_quant, + fp8_blockwise_weight_quant, + torch_blockwise_scale_act_quant, + torch_blockwise_scale_weight_quant, + triton_quantize_fp8_block, +) + +device = torch.device("cuda") + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + A_shape: tuple[int] + block_m: int + block_k: int + + +@dataclass(frozen=True) +class ExperimentResult: + torch_us: float + fbgemm_us: float + deepgemm_us: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + A_shapes = [ + (1024, 1024), + (2048, 2048), + (4096, 4096), + (8192, 8192), + (16384, 16384), + (32768, 32768), + ] + block_m_opts = [1, 128] + block_k_opts = [ + 128, + ] + configs = [] + for A_shape, block_m, block_k in itertools.product( + A_shapes, + block_m_opts, + block_k_opts, + ): + configs.append( + ExperimentConfig( + A_shape=A_shape, + block_m=block_m, + block_k=block_k, + ) + ) + return configs + + +def run_experiment( + config: ExperimentConfig, args: argparse.Namespace +) -> ExperimentResult: + A = torch.randn( + *config.A_shape, + dtype=torch.bfloat16, + device=device, + ) + + # Torch and DeepGEMM implementations are specific to activation quantization (1 x block_size) + # and weight quantization (block_size x block_size) + if config.block_m == 1: + torch_func = torch.compile(torch_blockwise_scale_act_quant) + deepgemm_func = fp8_blockwise_act_quant + else: + torch_func = torch.compile(torch_blockwise_scale_weight_quant) + deepgemm_func = fp8_blockwise_weight_quant + + # Validate output shapes and strides + torch_out, torch_scale = torch_func(A, tile_size=config.block_k) + deepgemm_out, deepgemm_scale = deepgemm_func(A, block_size=config.block_k) + fbgemm_out, fbgemm_scale = triton_quantize_fp8_block( + A, block_m=config.block_m, block_k=config.block_k, k_major=True + ) + assert torch_out.shape == deepgemm_out.shape == fbgemm_out.shape + assert torch_out.stride() == deepgemm_out.stride() == fbgemm_out.stride() + assert torch_scale.shape == deepgemm_scale.shape == fbgemm_scale.shape + assert torch_scale.stride() == deepgemm_scale.stride() == fbgemm_scale.stride() + + # Do benchmarking + torch_us = benchmark_microseconds(torch_func, A, tile_size=config.block_k) + deepgemm_us = benchmark_microseconds( + fp8_blockwise_act_quant, A, block_size=config.block_k + ) + fbgemm_us = benchmark_microseconds( + triton_quantize_fp8_block, + A, + block_m=config.block_m, + block_k=config.block_k, + k_major=True, + ) + + return ExperimentResult( + torch_us=round(torch_us, 3), + fbgemm_us=round(fbgemm_us, 3), + deepgemm_us=round(deepgemm_us, 3), + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "A_shape", + "block_shape", + "torch_us", + "fbgemm_us", + "deepgemm_us", + ] + rows = [] + for experiment in experiments: + A_shape = f"({experiment.config.A_shape[0]}, {experiment.config.A_shape[1]})" + block_shape = f"({experiment.config.block_m},{experiment.config.block_k})" + 
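+        # Each row: A shape, block shape, and median latency (us) for torch, fbgemm, and deepgemm.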
rows.append( + [ + A_shape, + block_shape, + experiment.result.torch_us, + experiment.result.fbgemm_us, + experiment.result.deepgemm_us, + ] + ) + print(tabulate(rows, headers=headers)) + + +def main(args: argparse.Namespace): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + for config in tqdm(configs): + result = run_experiment(config, args) + results.append(Experiment(config=config, result=result)) + + # Use Tabulate to print results + print_results(results) + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument("--compile", action="store_true") + args = arg_parser.parse_args() + main(args) diff --git a/benchmarks/float8/utils.py b/benchmarks/float8/utils.py index d4cdfeef20..1a79baec5d 100644 --- a/benchmarks/float8/utils.py +++ b/benchmarks/float8/utils.py @@ -11,6 +11,7 @@ import torch.utils.benchmark as benchmark from torch.profiler import ProfilerActivity, profile +from triton.testing import do_bench def profiler_output_to_filtered_time_by_kernel_name( @@ -428,3 +429,12 @@ def do_benchmarks( tops_sec = float(tops) / time_sec pct_top_peak = tops_sec / peak_tops return time_sec, tops_sec, pct_top_peak + + +def benchmark_microseconds(f, *args, warmup=25, rep=100, **kwargs): + return ( + do_bench( + lambda: f(*args, **kwargs), warmup=warmup, rep=rep, return_mode="median" + ) + * 1e3 + ) diff --git a/test/prototype/test_blockwise_triton.py b/test/prototype/test_blockwise_triton.py deleted file mode 100644 index 8aab73f7e8..0000000000 --- a/test/prototype/test_blockwise_triton.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. 
- -import pytest -import torch - -from packaging import version - -triton = pytest.importorskip("triton", reason="Triton required to run this test") - -from torchao.prototype.blockwise_fp8.blockwise_quantization import ( - blockwise_fp8_gemm, - fp8_blockwise_act_quant, - fp8_blockwise_weight_dequant, - fp8_blockwise_weight_quant, -) -from torchao.utils import is_sm_at_least_89 - -BLOCKWISE_SIZE_MNK = [ - (2, 512, 128), - (3, 2048, 2048), - (4, 3584, 640), - (13, 8704, 8576), - (26, 18944, 1664), - (67, 6656, 1408), -] - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize("_, N, K", BLOCKWISE_SIZE_MNK) -@pytest.mark.parametrize( - "dtype", - [torch.float8_e4m3fn, torch.float8_e5m2] - if is_sm_at_least_89() - else [torch.float8_e5m2], -) -def test_blockwise_quant_dequant(_, N, K, dtype): - x = torch.randn(N, K).cuda() - qx, s = fp8_blockwise_weight_quant(x, dtype=dtype) - x_reconstructed = fp8_blockwise_weight_dequant(qx, s) - error = torch.norm(x - x_reconstructed) / torch.norm(x) - print(f"Relative Error: {error.item():.6f}") - - assert error < 0.1, "Quant-Dequant error is too high" - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif( - version.parse(triton.__version__) < version.parse("3.3.0"), - reason="Triton version < 3.3.0, test skipped", -) -@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK) -@pytest.mark.parametrize( - "dtype", - [torch.float8_e4m3fn, torch.float8_e5m2] - if is_sm_at_least_89() - else [torch.float8_e5m2], -) -def test_blockwise_fp8_gemm(M, N, K, dtype): - A = torch.randn(M, K).cuda() - B = torch.randn(N, K).cuda() - C = A @ B.T - A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype) - B_q, B_s = fp8_blockwise_weight_quant(B, dtype=dtype) - C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s) - error = torch.norm(C - C_q) / torch.norm(C) - print(f"Relative Error: {error.item():.6f}") - - assert error < 0.1, "Quantize gemm error is too high" diff --git a/test/prototype/test_fp8_blockwise_kernels.py b/test/prototype/test_fp8_blockwise_kernels.py new file mode 100644 index 0000000000..1081d5dfc7 --- /dev/null +++ b/test/prototype/test_fp8_blockwise_kernels.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
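+#
+# Tests for the fp8 blockwise quantization kernels: quant/dequant round-trip SQNR,
+# blockwise fp8 GEMM accuracy, and parity of fbgemm's triton_quantize_fp8_block
+# with the pure PyTorch reference quantizers.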
+ +import pytest +import torch + +from packaging import version +from torchao.float8.float8_utils import compute_error +from torchao.prototype.blockwise_fp8.kernels import ( + torch_blockwise_scale_act_quant, + torch_blockwise_scale_weight_quant, + triton_quantize_fp8_block, +) +from torchao.testing.utils import skip_if_rocm + +triton = pytest.importorskip("triton", reason="Triton required to run this test") + +from torchao.prototype.blockwise_fp8.kernels import ( + blockwise_fp8_gemm, + fp8_blockwise_act_quant, + fp8_blockwise_weight_dequant, + fp8_blockwise_weight_quant, +) +from torchao.utils import is_sm_at_least_89 + +BLOCKWISE_SIZE_MNK = [ + (2, 512, 128), + (3, 2048, 2048), + (4, 3584, 640), + (13, 8704, 8576), + (26, 18944, 1664), + (67, 6656, 1408), +] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("_, N, K", BLOCKWISE_SIZE_MNK) +@pytest.mark.parametrize( + "dtype", + [torch.float8_e4m3fn, torch.float8_e5m2] + if is_sm_at_least_89() + else [torch.float8_e5m2], +) +def test_blockwise_quant_dequant(_, N, K, dtype): + x = torch.randn(N, K).cuda() + qx, s = fp8_blockwise_weight_quant(x, dtype=dtype) + x_reconstructed = fp8_blockwise_weight_dequant(qx, s) + sqnr = compute_error(x, x_reconstructed) + assert sqnr >= 25.0, f"SQNR {sqnr:.2f} must be >= 25.0" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + version.parse(triton.__version__) < version.parse("3.3.0"), + reason="Triton version < 3.3.0, test skipped", +) +@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK) +@pytest.mark.parametrize( + "dtype", + [torch.float8_e4m3fn, torch.float8_e5m2] + if is_sm_at_least_89() + else [torch.float8_e5m2], +) +def test_blockwise_fp8_gemm(M, N, K, dtype): + A = torch.randn(M, K).cuda() + B = torch.randn(N, K).cuda() + C = A @ B.T + A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype) + B_q, B_s = fp8_blockwise_weight_quant(B, dtype=dtype) + C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s) + sqnr = compute_error(C, C_q) + assert sqnr >= 25.0, f"SQNR {sqnr:.2f} must be >= 25.0" + + +@skip_if_rocm("ROCm enablement in progress") +@pytest.mark.parametrize("tile_size", [128, 256]) +def test_triton_quantize_fp8_act_quant(tile_size: int): + device = "cuda" + M, K = 256, 256 + x = torch.randn(M, K, device=device) + + # Get the quantized tensor and scales using triton implementation + # Use block_m=1 to match the narrow tiles (1 x tile_size) in the reference implementation + triton_fp8, triton_scale = triton_quantize_fp8_block( + x, block_m=1, block_k=tile_size + ) + + # Get the quantized tensor and scales using reference implementation + ref_fp8, ref_scale = torch_blockwise_scale_act_quant(x, tile_size=tile_size) + + # Convert both to float32 for comparison + triton_fp32 = triton_fp8.to(torch.float32) + ref_fp32 = ref_fp8.to(torch.float32) + + # Check that the quantized tensors are close + # Note: We use a relatively high tolerance because the implementations might have + # slight differences in how they handle edge cases, rounding, etc. + assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-2, atol=1e-2), ( + f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}" + ) + + # Check that the scales are close + # Note: The scales might be stored differently (reciprocal vs. 
direct), so we need to + # be careful about how we compare them + + # In triton_quantize_fp8_block, scales are stored as reciprocals (1/scale) + # In torch_blockwise_scale_act_quant, scales are stored directly + # So we need to take the reciprocal of one of them for comparison + + # Reshape triton_scale to match ref_scale shape for comparison + triton_scale_reshaped = triton_scale.reshape(M, -1) + + # Compare reciprocal of triton_scale with ref_scale + assert torch.allclose( + 1.0 / triton_scale_reshaped, ref_scale, rtol=1e-2, atol=1e-2 + ), ( + f"Scales differ: max diff = {(1.0 / triton_scale_reshaped - ref_scale).abs().max().item()}" + ) + + +@skip_if_rocm("ROCm enablement in progress") +@pytest.mark.parametrize("tile_size", [128, 256]) +def test_triton_quantize_fp8_weight_quant(tile_size: int): + device = "cuda" + # Make sure dimensions are multiples of tile_size for clean comparison + M = tile_size * 2 + K = tile_size * 2 + x = torch.randn(M, K, device=device) + + # Get the quantized tensor and scales using triton implementation + triton_fp8, triton_scale = triton_quantize_fp8_block( + x, block_m=tile_size, block_k=tile_size + ) + + # Get the quantized tensor and scales using reference implementation + ref_fp8, ref_scale = torch_blockwise_scale_weight_quant(x, tile_size=tile_size) + + # Convert both to float32 for comparison + triton_fp32 = triton_fp8.to(torch.float32) + ref_fp32 = ref_fp8.to(torch.float32) + + # Check that the quantized tensors are close + assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-2, atol=1e-2), ( + f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}" + ) + + # Check that the scales are close + # In triton_quantize_fp8_block, scales are stored as reciprocals (1/scale) + # In torch_blockwise_scale_weight_quant, scales are stored directly + + # Compare reciprocal of triton_scale with ref_scale + assert torch.allclose(1.0 / triton_scale, ref_scale, rtol=1e-2, atol=1e-2), ( + f"Scales differ: max diff = {(1.0 / triton_scale - ref_scale).abs().max().item()}" + ) diff --git a/torchao/prototype/blockwise_fp8/__init__.py b/torchao/prototype/blockwise_fp8/__init__.py index f2842417e4..b3bb87b762 100644 --- a/torchao/prototype/blockwise_fp8/__init__.py +++ b/torchao/prototype/blockwise_fp8/__init__.py @@ -1,5 +1,5 @@ from .blockwise_linear import BlockwiseQuantLinear -from .blockwise_quantization import ( +from .kernels import ( blockwise_fp8_gemm, fp8_blockwise_act_quant, fp8_blockwise_weight_dequant, diff --git a/torchao/prototype/blockwise_fp8/blockwise_linear.py b/torchao/prototype/blockwise_fp8/blockwise_linear.py index c25b946732..a911de22cf 100644 --- a/torchao/prototype/blockwise_fp8/blockwise_linear.py +++ b/torchao/prototype/blockwise_fp8/blockwise_linear.py @@ -7,7 +7,7 @@ import torch from torch import nn -from torchao.prototype.blockwise_fp8.blockwise_quantization import ( +from torchao.prototype.blockwise_fp8.kernels import ( blockwise_fp8_gemm, fp8_blockwise_act_quant, ) diff --git a/torchao/prototype/blockwise_fp8/blockwise_quantization.py b/torchao/prototype/blockwise_fp8/kernels.py similarity index 56% rename from torchao/prototype/blockwise_fp8/blockwise_quantization.py rename to torchao/prototype/blockwise_fp8/kernels.py index 1d296249f9..3a296ace5a 100644 --- a/torchao/prototype/blockwise_fp8/blockwise_quantization.py +++ b/torchao/prototype/blockwise_fp8/kernels.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import math -from typing import Tuple +from typing import Optional, Tuple import torch import triton @@ -277,3 +277,231 @@ def fp8_blockwise_weight_dequant( ) fp8_blockwise_weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) return y + + +# original implementation from fbgemm_gpu: +# https://github.com/pytorch/FBGEMM/blob/b19401e913fcdff536dc097fa3013a0a9d66256e/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py#L3091 +def triton_quantize_fp8_block( + x: torch.Tensor, + block_m: int = 128, + block_k: int = 128, + scale_ub: Optional[torch.Tensor] = None, + k_major: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize a tensor to fp8 with block-wise scalings. + + Scale per block i, j is computed as 1 / (MAX_FP8 / max(abs(x[i:i+block_m, j:j+block_k]))) + + Args: + x (torch.Tensor): [M, K] higher precision input tensor. + block_m (int): Block size for M dimension of scale. + block_k (int): Block size for K dimension of scale. + scale_ub: Maximum allowed value for scale. + k_major (bool): Whether output scales should be K major (True) or MN major (False). + + Returns: + torch.Tensor : [M, K] fp8 scaled tensor. + torch.Tensor: [cdiv(M, block_m), cdiv(K, block_k)] reciprocal scale tensor per block + if k_major is True, otherwise [cdiv(K, block_k), cdiv(M, block_M)]. + """ + assert x.device != torch.device("cpu"), ( + "Blockwise quantization not support on cpu, please use row-wise quantization instead." + ) + pt_dtype = torch.float8_e4m3fn + tl_dtype = tl.float8e4nv + max_fp8 = torch.finfo(pt_dtype).max + eps = 1e-12 + + x_shape = x.shape + x = x.view(-1, x.size(-1)) + M, K = x.shape + grid_m = triton.cdiv(M, block_m) + grid_k = triton.cdiv(K, block_k) + if k_major: + x_scale = torch.empty((grid_m, grid_k), device=x.device, dtype=torch.float32) + else: + x_scale = torch.empty((grid_k, grid_m), device=x.device, dtype=torch.float32) + x_fp8 = torch.empty((M, K), device=x.device, dtype=pt_dtype) + + _kernel_quantize_fp8_block[(grid_m * grid_k,)]( + x, + x_scale, + x_fp8, + scale_ub, + M, + K, + x.stride(0), + x.stride(1), + x_fp8.stride(0), + x_fp8.stride(1), + x_scale.stride(0), + x_scale.stride(1), + TL_FP8_DTYPE=tl_dtype, + MAX_FP8=max_fp8, + EPS=eps, + CLAMP_MAX=scale_ub is not None, + BLOCK_M=block_m, + BLOCK_K=block_k, + K_MAJOR=k_major, + ) + + return x_fp8.view(x_shape), x_scale + + +# original implementation from fbgemm_gpu: +# https://github.com/pytorch/FBGEMM/blob/b19401e913fcdff536dc097fa3013a0a9d66256e/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py#L3005 +@triton.jit +def _kernel_quantize_fp8_block( + A, + A_scale, + A_fp8, + scale_ub, + M, + K, + stride_am, + stride_ak, + stride_om, + stride_ok, + stride_a_scale_m, + stride_a_scale_k, + TL_FP8_DTYPE: tl.constexpr, + MAX_FP8: tl.constexpr, + EPS: tl.constexpr, + CLAMP_MAX: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + K_MAJOR: tl.constexpr, +) -> None: + """Quantize and scale each [BLOCK_M, BLOCK_K] block. + + Scale per block i, j is computed as 1 / (MAX_FP8 / max(abs(A[i:i+BLOCK_M, j:j+BLOCK_K]))) + + Kernel naively iterates through matrix with [BLOCK_M, BLOCK_K] tiles. + + Todo: + * Better tiling and ordering schemes. + + Args: + A (Tensor): [M, K] higher precision input tensor. + A_scale (Tensor): [cdiv(M, BLOCK_M), cdiv(K, BLOCK_K)] reciprocal scale tensor per block. + A_fp8 (Tensor): [M, K] fp8 scaled tensor. A_fp8 = A * a_scale + scale_ub (Tensor): [1] Maximum allowed value for scale. + M (int): Number of rows. + K (int): Number of columns. 
+ stride_am (int): Stride of m dimension of A. + stride_ak (int): Stride of k dimension of A. + stride_om (int): Stride of m dimension of output. + stride_ok (int): Stride of k dimension of output. + stride_a_scale_m (int): Stride of m dimension of A_scale. + stride_a_scale_k (int): Stride of k dimension of A_scale. + TL_FP8_DTYPE (tl.dtype): Target fp8 datatype. + MAX_FP8 (float): Maxmimum expressible value for FP8. + EPS (float): Epsilon value for numerical stability. + CLAMP_MAX (bool): Whether to apply scale_ub. + BLOCK_M (int): Block size for M dimension of A_scale and kernel. + BLOCK_K (int): Block size for K dimension of A_scale and kernel. + K_MAJOR (bool): Whether output scales should be K major (True) or MN major (False). + """ + pid = tl.program_id(0) + grid_k = tl.cdiv(K, BLOCK_K) + block_m = pid // grid_k + block_k = pid % grid_k + rm = block_m * BLOCK_M + tl.arange(0, BLOCK_M) + rk = block_k * BLOCK_K + tl.arange(0, BLOCK_K) + a_offset = rm[:, None] * stride_am + rk[None, :] * stride_ak + out_offset = rm[:, None] * stride_om + rk[None, :] * stride_ok + a_mask = (rm < M)[:, None] & (rk < K)[None, :] + a_block = tl.load(A + a_offset, mask=a_mask, other=0.0) + + block_max = tl.max(tl.abs(a_block)) + # Apply appropriate clamping. + if CLAMP_MAX: + ub = tl.load(scale_ub) + block_max = tl.clamp(block_max, EPS, ub) + else: + block_max = tl.maximum(block_max, EPS) + scale = MAX_FP8 / block_max + + # Write in transposed order if specified. + if K_MAJOR: + scale_offset = block_m * stride_a_scale_m + block_k * stride_a_scale_k + else: + scale_offset = block_k * stride_a_scale_m + block_m * stride_a_scale_k + tl.store(A_scale + scale_offset, 1.0 / scale) + a_fp8 = a_block * scale + # Clamp A to fp8 range to make sure there's no overflow. + # This is required for AMD. Nvidia's default saturation + # handles it, but it's nice to have anyway. 
+ a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8) + a_fp8.to(TL_FP8_DTYPE) + tl.store(A_fp8 + out_offset, a_fp8, mask=a_mask) + + +def torch_blockwise_scale_act_quant(x, tile_size=128): + """ + Input: weight tensor in high precision + Output: weight tensor in float8, and scale, tiled 1 by tile_size + """ + assert x.is_contiguous(), "input tensor must be contiguous" + orig_shape = x.shape + + # Reshape 2D+ input tensor into 2D tensor with shape (leading_dims, tile_size) + x = x.reshape(-1, tile_size) + + # Compute amax along last dim (i.e., the block) + x_amax = x.abs().max(dim=1).values.unsqueeze(1).clamp(1e-4) + + # Convert amax to scale + fp8_dtype_max, fp8_dtype_min = ( + torch.finfo(torch.float8_e4m3fn).max, + torch.finfo(torch.float8_e4m3fn).min, + ) + s = fp8_dtype_max / x_amax + + # Apply scale and clamp + x = (x * s).clamp(min=fp8_dtype_min, max=fp8_dtype_max).to(torch.float8_e4m3fn) + + # Reshape quantized output back to original shape and reshape scales accordingly + x = x.reshape(*orig_shape) + s = s.reshape(orig_shape[0], -1).to(torch.float) + return x, s + + +def torch_blockwise_scale_weight_quant(x, tile_size=128): + """ + Input: weight tensor in high precision + Output: weight tensor in float8, and scale, tiled tile_size by tile_size + """ + assert len(x.shape) == 2, "input shape must be 2D" + assert x.is_contiguous(), "input tensor must be contiguous" + height, width = x.shape + + # Compute block sizes + t_h = height // tile_size + t_w = width // tile_size + + # Reshape 2D input tensor into 4D tensor with shape (t_h, t_w, tile_size * tile_size) + x = x.reshape(t_h, tile_size, t_w, tile_size) + x = x.permute(0, 2, 1, 3) + x = x.reshape(-1, tile_size * tile_size) + + # Compute amax along last dim (i.e., the block) + m = x.abs().max(dim=1).values.unsqueeze(1).clamp(1e-4) + + # Convert amax to scale + fp8_dtype_max, fp8_dtype_min = ( + torch.finfo(torch.float8_e4m3fn).max, + torch.finfo(torch.float8_e4m3fn).min, + ) + s = fp8_dtype_max / m + + # Apply scale and clamp + x = (x * s).clamp(min=fp8_dtype_min, max=fp8_dtype_max).to(torch.float8_e4m3fn) + + # Reshape quantized output and scales back to 2D + x = x.reshape(t_h, t_w, tile_size, tile_size) + x = x.permute(0, 2, 1, 3) + x = x.reshape(height, width) + s = s.reshape(t_h, t_w).to(torch.float) + return x, s From 5a4f0faab2c77094dbcb8fac78cf6363e69d4888 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 24 Jul 2025 19:50:50 -0700 Subject: [PATCH 2/2] make fp8 blockwise linear differentiable; use new kernels stack-info: PR: https://github.com/pytorch/ao/pull/2602, branch: danielvegamyhre/stack/16 --- .../blockwise_fp8/test_blockwise_linear.py | 65 +++++++ .../test_fp8_blockwise_kernels.py | 0 torchao/prototype/blockwise_fp8/__init__.py | 4 +- .../blockwise_fp8/blockwise_linear.py | 169 ++++++++++++++---- .../blockwise_fp8/deep_gemm_utils.py | 29 +++ torchao/prototype/blockwise_fp8/kernels.py | 6 + 6 files changed, 237 insertions(+), 36 deletions(-) create mode 100644 test/prototype/blockwise_fp8/test_blockwise_linear.py rename test/prototype/{ => blockwise_fp8}/test_fp8_blockwise_kernels.py (100%) create mode 100644 torchao/prototype/blockwise_fp8/deep_gemm_utils.py diff --git a/test/prototype/blockwise_fp8/test_blockwise_linear.py b/test/prototype/blockwise_fp8/test_blockwise_linear.py new file mode 100644 index 0000000000..00679841cb --- /dev/null +++ b/test/prototype/blockwise_fp8/test_blockwise_linear.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch + +from torchao.float8.float8_utils import compute_error +from torchao.prototype.blockwise_fp8.blockwise_linear import Float8BlockwiseLinear + +triton = pytest.importorskip("triton", reason="Triton required to run this test") + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("in_features", [1024]) +@pytest.mark.parametrize("out_features", [1024]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("block_size", [128]) +def test_blockwise_quant_linear_fwd_bwd( + in_features, + out_features, + batch_size, + block_size, +): + if in_features % block_size != 0 or out_features % block_size != 0: + pytest.skip(f"Dimensions must be divisible by block_size={block_size}") + + torch.random.manual_seed(0) + layer_test = Float8BlockwiseLinear( + in_features=in_features, + out_features=out_features, + block_size=block_size, + ).cuda() + + torch.random.manual_seed(0) + layer_ref = torch.nn.Linear( + in_features=in_features, + out_features=out_features, + ).cuda() + + # Create input tensor + x_test = torch.randn(batch_size, in_features).cuda() + x_ref = x_test.clone().detach().requires_grad_(True) + + # Forward pass + y_test = layer_test(x_test) + y_ref = layer_ref(x_ref) + + # Compare outputs + sqnr = compute_error(y_ref, y_test) + assert sqnr >= 25.0, f"SQNR: {sqnr.item()} must be >= 25.0" + + # Backward pass + y_test.sum().backward() + y_ref.sum().backward() + + # Compare input grads + sqnr = compute_error(x_ref.grad, x_test.grad) + assert sqnr >= 25.0, f"SQNR: {sqnr} must be >= 25.0" + + # Compare weight grads + sqnr = compute_error(layer_ref.weight, layer_test.weight) + assert sqnr >= 25.0, f"SQNR: {sqnr} must be >= 25.0" diff --git a/test/prototype/test_fp8_blockwise_kernels.py b/test/prototype/blockwise_fp8/test_fp8_blockwise_kernels.py similarity index 100% rename from test/prototype/test_fp8_blockwise_kernels.py rename to test/prototype/blockwise_fp8/test_fp8_blockwise_kernels.py diff --git a/torchao/prototype/blockwise_fp8/__init__.py b/torchao/prototype/blockwise_fp8/__init__.py index b3bb87b762..2a09596181 100644 --- a/torchao/prototype/blockwise_fp8/__init__.py +++ b/torchao/prototype/blockwise_fp8/__init__.py @@ -1,4 +1,4 @@ -from .blockwise_linear import BlockwiseQuantLinear +from .blockwise_linear import Float8BlockwiseLinear from .kernels import ( blockwise_fp8_gemm, fp8_blockwise_act_quant, @@ -8,7 +8,7 @@ __all__ = [ "blockwise_fp8_gemm", - "BlockwiseQuantLinear", + "Float8BlockwiseLinear", "fp8_blockwise_act_quant", "fp8_blockwise_weight_quant", "fp8_blockwise_weight_dequant", diff --git a/torchao/prototype/blockwise_fp8/blockwise_linear.py b/torchao/prototype/blockwise_fp8/blockwise_linear.py index a911de22cf..8f721b8e59 100644 --- a/torchao/prototype/blockwise_fp8/blockwise_linear.py +++ b/torchao/prototype/blockwise_fp8/blockwise_linear.py @@ -7,13 +7,107 @@ import torch from torch import nn +from torchao.core.config import AOBaseConfig +from torchao.prototype.blockwise_fp8.deep_gemm_utils import ( + scaled_mm_deep_gemm_128_1_128_1, + scaled_mm_deep_gemm_128_1_128_128, +) from torchao.prototype.blockwise_fp8.kernels import ( - blockwise_fp8_gemm, fp8_blockwise_act_quant, + triton_quantize_fp8_block, +) +from torchao.quantization.transform_module import ( + register_quantize_module_handler, ) -class BlockwiseQuantLinear(nn.Module): +class 
fp8_blockwise_mm(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, block_size): + assert block_size == 128, "Only support block_size=128" + + # Temporarily reshape x to 2D tensor + x_orig_shape = x.shape + x = x.reshape(-1, x_orig_shape[-1]) + + # Triton kernel from DeepGEMM currently has the fastest activation quantization (1 x block_size) + x_fp8, x_scale = fp8_blockwise_act_quant(x, block_size) + + # fbgemm currently has the fastest weight quantization (block_size x block_size) + weight_t_fp8, weight_t_scale = triton_quantize_fp8_block( + weight, + block_m=block_size, + block_k=block_size, + k_major=True, # For [M,K] -> [K,M] in column-major + ) + + # DeepGEMM for blockwise GEMM where activation has (1 x block_size) scaling granularity + # and weight has (block_size x block_size) scaling granularity. + out = scaled_mm_deep_gemm_128_1_128_128( + x_fp8, + x_scale, + weight_t_fp8, + weight_t_scale, + ) + ctx.save_for_backward(x, weight) + ctx.block_size = block_size + return out + + @staticmethod + def backward(ctx, grad_output): + x, weight = ctx.saved_tensors + block_size = ctx.block_size + + # left operand must be row-major + grad_output_fp8, grad_output_scale = fp8_blockwise_act_quant( + grad_output, + block_size, + ) + + # right operand must be column-major + weight_t_fp8, weight_t_scale = triton_quantize_fp8_block( + weight, + block_m=block_size, + block_k=block_size, + k_major=False, # For [M,K] -> [K,M] in row-major + ) + weight_t_fp8 = weight_t_fp8.t().contiguous().t() # To col-major + + # DeepGEMM for blockwise GEMM where left operand has (1 x block_size) scaling granularity + # and right operand has (block_size x block_size) scaling granularity. + # grad_x = grad_output @ weight.T + grad_x = scaled_mm_deep_gemm_128_1_128_128( + grad_output_fp8, + weight_t_fp8, + 1.0 / grad_output_scale, + 1.0 / weight_t_scale, + ) + + # left operand must be row-major + grad_output_t_fp8, grad_output_t_scale = fp8_blockwise_act_quant( + grad_output.t().contiguous(), + block_size, + ) + + # right operand must be column-major + x_fp8, x_scale = fp8_blockwise_act_quant( + x, + block_size, + ) + x_fp8 = x_fp8.t().contiguous().t() # To col-major + + # DeepGEMM for blockwise GEMM where both operands have (1 x block_size) scaling granularity. + # grad_weight = grad_output.T @ x + grad_weight = scaled_mm_deep_gemm_128_1_128_1( + grad_output_t_fp8, + x_fp8, + 1.0 / grad_output_t_scale, + 1.0 / x_scale, + ) + return grad_x, grad_weight, None, None + + +class Float8BlockwiseLinear(nn.Linear): """ Custom linear layer with support for quantized weights and optional bias. @@ -25,53 +119,60 @@ class BlockwiseQuantLinear(nn.Module): dtype (torch.dtype): Data type for the weights. Defaults to torch.float8_e4m3fn. """ - dtype = torch.bfloat16 + supported_dtypes = [ + torch.bfloat16, + ] def __init__( self, - in_features: int, - out_features: int, - bias: bool = False, + *args, block_size: int = 128, - dtype: torch.dtype = torch.float8_e4m3fn, + dtype=torch.bfloat16, + **kwargs, ): - super().__init__() - supported_dtypes = [ - torch.float8_e4m3fn, - torch.float8_e5m2, - ] - assert dtype in supported_dtypes, ( - f"Unsupported dtype: {dtype}. 
Supported dtypes: {supported_dtypes}" - ) - scale_in_features = (in_features + block_size - 1) // block_size - scale_out_features = (out_features + block_size - 1) // block_size - self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype)) - self.weight.scale = self.scale = nn.Parameter( - torch.empty(scale_out_features, scale_in_features, dtype=torch.float32) + super().__init__(*args, **kwargs) + + assert dtype in self.supported_dtypes, ( + f"Unsupported dtype: {dtype}. Supported dtypes: {self.supported_dtypes}" ) self.block_size = block_size - self.dtype - - if bias: - self.bias = nn.Parameter(torch.empty(out_features)) - else: - self.register_parameter("bias", None) + self.dtype = dtype def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass for the custom linear layer. Args: - x (torch.Tensor): Input tensor. + x (torch.Tensor): input tensor. Returns: torch.Tensor: Transformed tensor after linear computation. """ - x, scale = fp8_blockwise_act_quant(x, self.block_size, self.dtype) - y = blockwise_fp8_gemm( - x, scale, self.weight, self.weight.scale, self.block_size - ) + return fp8_blockwise_mm.apply(x, self.weight, self.block_size) + + @classmethod + def from_float( + cls, + mod, + ): + assert mod.bias is None, "unsupported" + assert mod.in_features % 128 == 0, "unsupported" + assert mod.out_features % 128 == 0, "unsupported" + with torch.device("meta"): + new_mod = cls( + mod.in_features, + mod.out_features, + bias=False, + ) + new_mod.weight = mod.weight + new_mod.bias = mod.bias + return new_mod + + +class Float8BlockwiseLinearConfig(AOBaseConfig): + pass + - if self.bias is not None: - y += self.bias - return y +@register_quantize_module_handler(Float8BlockwiseLinearConfig) +def _deep_gemm_float8_inference_linear_transform(module, config): + return Float8BlockwiseLinear.from_float(module) diff --git a/torchao/prototype/blockwise_fp8/deep_gemm_utils.py b/torchao/prototype/blockwise_fp8/deep_gemm_utils.py new file mode 100644 index 0000000000..f2a6142ce0 --- /dev/null +++ b/torchao/prototype/blockwise_fp8/deep_gemm_utils.py @@ -0,0 +1,29 @@ +import sys + +import torch + +try: + import deep_gemm +except ImportError: + print("Please install deepgemm to use this feature") + sys.exit(0) + + +def scaled_mm_deep_gemm_128_1_128_128(a, b, a_scale, b_scale): + M, K = a.shape + N, K = b.shape + out = torch.empty((M, N), dtype=torch.bfloat16, device=a.device) + deep_gemm.gemm_fp8_fp8_bf16_nt((a, a_scale), (b, b_scale), out=out) + return out + + +def scaled_mm_deep_gemm_128_1_128_1(a, b, a_scale, b_scale): + M, K = a.shape + N, K = b.shape + # Note: the results from `wgrad_gemm_fp8_fp8_fp32_nt` are **accumulated** + # into this tensor. For now, we initialize with `zeros` to get correct + # numerics in toy examples. For a real use case, this will need to pass + # in the gradient tensor directly. 
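+    # Both operands carry (1 x 128) groupwise scales here, per the _128_1_128_1 naming.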
+ out = torch.zeros((M, N), dtype=torch.float, device=a.device) + deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((a, a_scale), (b, b_scale), out=out) + return out diff --git a/torchao/prototype/blockwise_fp8/kernels.py b/torchao/prototype/blockwise_fp8/kernels.py index 3a296ace5a..5f123c5c14 100644 --- a/torchao/prototype/blockwise_fp8/kernels.py +++ b/torchao/prototype/blockwise_fp8/kernels.py @@ -12,6 +12,12 @@ import triton.language as tl from triton import Config +# try: +# from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import triton_quantize_fp8_block +# except ImportError: +# print("Please install fbgemm-gpu to use this feature") +# sys.exit(1) + # Original implementation at https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py fp8_gemm_configs = [
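
Usage sketch (mirrors test_blockwise_linear.py above): a minimal example of the new
Float8BlockwiseLinear module, assuming a CUDA device with the deep_gemm package
installed, dimensions that are multiples of block_size=128, and no bias.

    import torch
    from torchao.prototype.blockwise_fp8 import Float8BlockwiseLinear

    # Both dimensions must be divisible by block_size=128; bias is not supported.
    layer = Float8BlockwiseLinear(1024, 1024, bias=False, block_size=128).cuda()
    x = torch.randn(16, 1024, device="cuda", requires_grad=True)

    y = layer(x)        # forward: fp8 blockwise GEMM via DeepGEMM
    y.sum().backward()  # backward: grad_x and grad_weight also use fp8 blockwise GEMMs

The patch also registers Float8BlockwiseLinearConfig with torchao's quantize_ flow
(via register_quantize_module_handler), so an existing bias-free nn.Linear with
128-divisible dimensions can be converted through Float8BlockwiseLinear.from_float.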