
Commit d0631c0

Add Float8BlockwiseLinear with Triton kernels for quantization and GEMMs
stack-info: PR: #2592, branch: danielvegamyhre/stack/15
1 parent 0e00df3 commit d0631c0
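
For orientation, here is a minimal sketch of the quantize-then-GEMM flow this commit adds, mirroring the usage in the new test file shown further down. The kernel names come from torchao.prototype.blockwise_fp8.kernels as imported in the diff; the concrete shapes and dtypes are illustrative assumptions, not part of the commit:

import torch
from torchao.prototype.blockwise_fp8.kernels import (
    blockwise_fp8_gemm_1x128_128x128,
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_quant,
)

# Illustrative shapes; the tests use K values that are multiples of the 128 block size.
A = torch.randn(16, 1024, device="cuda")    # activations (M x K)
B = torch.randn(2048, 1024, device="cuda")  # weights (N x K)

# Activations are quantized with (1 x 128) scaling tiles, weights with (128 x 128) tiles.
A_q, A_s = fp8_blockwise_act_quant(A, dtype=torch.float8_e4m3fn)
B_q, B_s = fp8_blockwise_weight_quant(B, dtype=torch.float8_e4m3fn)

# Blockwise-scaled FP8 GEMM approximating the high-precision A @ B.T.
C = blockwise_fp8_gemm_1x128_128x128(A_q, A_s, B_q, B_s)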

File tree

10 files changed: +1154 -402 lines changed

benchmarks/benchmark_blockwise_scaled_linear_triton.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 from triton.testing import do_bench

 from torchao.float8.float8_utils import compute_error
-from torchao.prototype.blockwise_fp8.blockwise_quantization import (
+from torchao.prototype.blockwise_fp8.kernels import (
     blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
     fp8_blockwise_weight_quant,
Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
import argparse
import itertools
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm
from utils import benchmark_microseconds

from torchao.prototype.blockwise_fp8.kernels import (
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_quant,
    torch_blockwise_scale_act_quant,
    torch_blockwise_scale_weight_quant,
    triton_quantize_fp8_block,
)

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    A_shape: tuple[int]
    block_m: int
    block_k: int


@dataclass(frozen=True)
class ExperimentResult:
    torch_us: float
    fbgemm_us: float
    deepgemm_us: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    A_shapes = [
        (1024, 1024),
        (2048, 2048),
        (4096, 4096),
        (8192, 8192),
        (16384, 16384),
        (32768, 32768),
    ]
    block_m_opts = [1, 128]
    block_k_opts = [
        128,
    ]
    configs = []
    for A_shape, block_m, block_k in itertools.product(
        A_shapes,
        block_m_opts,
        block_k_opts,
    ):
        configs.append(
            ExperimentConfig(
                A_shape=A_shape,
                block_m=block_m,
                block_k=block_k,
            )
        )
    return configs


def run_experiment(
    config: ExperimentConfig, args: argparse.Namespace
) -> ExperimentResult:
    A = torch.randn(
        *config.A_shape,
        dtype=torch.bfloat16,
        device=device,
    )

    # Torch and DeepGEMM implementations are specific to activation quantization (1 x block_size)
    # and weight quantization (block_size x block_size)
    if config.block_m == 1:
        torch_func = torch.compile(torch_blockwise_scale_act_quant)
        deepgemm_func = fp8_blockwise_act_quant
    else:
        torch_func = torch.compile(torch_blockwise_scale_weight_quant)
        deepgemm_func = fp8_blockwise_weight_quant

    # Validate output shapes and strides
    torch_out, torch_scale = torch_func(A, tile_size=config.block_k)
    deepgemm_out, deepgemm_scale = deepgemm_func(A, block_size=config.block_k)
    fbgemm_out, fbgemm_scale = triton_quantize_fp8_block(
        A, block_m=config.block_m, block_k=config.block_k, k_major=True
    )
    assert torch_out.shape == deepgemm_out.shape == fbgemm_out.shape
    assert torch_out.stride() == deepgemm_out.stride() == fbgemm_out.stride()
    assert torch_scale.shape == deepgemm_scale.shape == fbgemm_scale.shape
    assert torch_scale.stride() == deepgemm_scale.stride() == fbgemm_scale.stride()

    # Do benchmarking
    torch_us = benchmark_microseconds(torch_func, A, tile_size=config.block_k)
    deepgemm_us = benchmark_microseconds(
        fp8_blockwise_act_quant, A, block_size=config.block_k
    )
    fbgemm_us = benchmark_microseconds(
        triton_quantize_fp8_block,
        A,
        block_m=config.block_m,
        block_k=config.block_k,
        k_major=True,
    )

    return ExperimentResult(
        torch_us=round(torch_us, 3),
        fbgemm_us=round(fbgemm_us, 3),
        deepgemm_us=round(deepgemm_us, 3),
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "A_shape",
        "block_shape",
        "torch_us",
        "fbgemm_us",
        "deepgemm_us",
    ]
    rows = []
    for experiment in experiments:
        A_shape = f"({experiment.config.A_shape[0]}, {experiment.config.A_shape[1]})"
        block_shape = f"({experiment.config.block_m},{experiment.config.block_k})"
        rows.append(
            [
                A_shape,
                block_shape,
                experiment.result.torch_us,
                experiment.result.fbgemm_us,
                experiment.result.deepgemm_us,
            ]
        )
    print(tabulate(rows, headers=headers))


def main(args: argparse.Namespace):
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config, args)
        results.append(Experiment(config=config, result=result))

    # Use Tabulate to print results
    print_results(results)


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--compile", action="store_true")
    args = arg_parser.parse_args()
    main(args)

benchmarks/float8/utils.py

Lines changed: 10 additions & 0 deletions
@@ -11,6 +11,7 @@
 
 import torch.utils.benchmark as benchmark
 from torch.profiler import ProfilerActivity, profile
+from triton.testing import do_bench
 
 
 def profiler_output_to_filtered_time_by_kernel_name(
@@ -428,3 +429,12 @@ def do_benchmarks(
     tops_sec = float(tops) / time_sec
     pct_top_peak = tops_sec / peak_tops
     return time_sec, tops_sec, pct_top_peak
+
+
+def benchmark_microseconds(f, *args, warmup=25, rep=100, **kwargs):
+    return (
+        do_bench(
+            lambda: f(*args, **kwargs), warmup=warmup, rep=rep, return_mode="median"
+        )
+        * 1e3
+    )
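
As used by the new benchmark script above, the helper is called with the function under test followed by its positional and keyword arguments, and returns a median runtime in microseconds (do_bench reports milliseconds, hence the * 1e3). A short usage sketch, assuming the same imports as the benchmark script:

import torch
from torchao.prototype.blockwise_fp8.kernels import fp8_blockwise_act_quant

A = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
# Median time in microseconds for a (1 x 128) blockwise activation quant.
t_us = benchmark_microseconds(fp8_blockwise_act_quant, A, block_size=128)
print(f"fp8_blockwise_act_quant: {t_us:.3f} us")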
Lines changed: 180 additions & 0 deletions
@@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

triton = pytest.importorskip("triton", reason="Triton required to run this test")

from packaging import version
from torchao.float8.float8_utils import compute_error
from torchao.prototype.blockwise_fp8.kernels import (
    blockwise_fp8_gemm_1x128_1x128,
    blockwise_fp8_gemm_1x128_128x128,
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_dequant,
    fp8_blockwise_weight_quant,
    torch_blockwise_scale_act_quant,
    torch_blockwise_scale_weight_quant,
    triton_quantize_fp8_block,
)
from torchao.testing.utils import skip_if_rocm

BLOCKWISE_SIZE_MNK = [
    (2, 512, 128),
    (3, 2048, 2048),
    (4, 3584, 640),
    (13, 8704, 8576),
    (26, 18944, 1664),
    (67, 6656, 1408),
]


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("_, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
def test_blockwise_quant_dequant(_, N, K, dtype):
    x = torch.randn(N, K).cuda()
    qx, s = fp8_blockwise_weight_quant(x, dtype=dtype)
    x_reconstructed = fp8_blockwise_weight_dequant(qx, s)
    sqnr = compute_error(x, x_reconstructed)
    assert sqnr >= 25.0, f"SQNR {sqnr:.2f} must be >= 25.0"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    version.parse(triton.__version__) < version.parse("3.3.0"),
    reason="Triton version < 3.3.0, test skipped",
)
@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
def test_blockwise_fp8_gemm_1x128_128x128(M, N, K, dtype):
    A = torch.randn(M, K).cuda()
    B = torch.randn(N, K).cuda()
    C = A @ B.T
    A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype)
    B_q, B_s = fp8_blockwise_weight_quant(B, dtype=dtype)
    C_q = blockwise_fp8_gemm_1x128_128x128(A_q, A_s, B_q, B_s)
    assert not C_q.isnan().any(), "C_q must not contain NaNs"
    sqnr = compute_error(C, C_q)
    assert sqnr >= 22.0, f"SQNR {sqnr:.2f} must be >= 22.0"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    version.parse(triton.__version__) < version.parse("3.3.0"),
    reason="Triton version < 3.3.0, test skipped",
)
@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
def test_blockwise_fp8_gemm_1x128_1x128(M, N, K, dtype):
    A = torch.randn(M, K).cuda()
    B = torch.randn(N, K).cuda()
    C = A @ B.T
    A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype)
    B_q, B_s = fp8_blockwise_act_quant(B, dtype=dtype)
    C_q = blockwise_fp8_gemm_1x128_1x128(A_q, A_s, B_q, B_s)
    assert not C_q.isnan().any(), "C_q must not contain NaNs"
    sqnr = compute_error(C, C_q)
    assert sqnr >= 22.0, f"SQNR {sqnr:.2f} must be >= 22.0"


@skip_if_rocm("ROCm not supported")
@pytest.mark.parametrize("tile_size", [128, 256])
@pytest.mark.parametrize("test_eps", [True, False])
def test_triton_quantize_fp8_act_quant(tile_size: int, test_eps: bool):
    device = "cuda"
    M, K = 256, 256
    x = torch.randn(M, K, device=device)

    # set one scaling block to 0s, so if nan guards/EPS are not applied, the
    # quantized tensor will have NaNs due to division by 0
    if test_eps:
        x[0, :tile_size] = 0.0

    # Get the quantized tensor and scales using triton implementation
    # Use block_m=1 to match the narrow tiles (1 x tile_size) in the reference implementation
    triton_fp8, triton_scale = triton_quantize_fp8_block(
        x, block_m=1, block_k=tile_size
    )
    assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs"

    # Get the quantized tensor and scales using reference implementation
    ref_fp8, ref_scale = torch_blockwise_scale_act_quant(x, tile_size=tile_size)
    assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs"

    # Convert both to float32 for comparison
    triton_fp32 = triton_fp8.to(torch.float32)
    ref_fp32 = ref_fp8.to(torch.float32)

    # Check that the quantized tensors are close
    # Note: We use a relatively high tolerance because the implementations might have
    # slight differences in how they handle edge cases, rounding, etc.
    assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-2, atol=1e-2), (
        f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}"
    )

    # Check that the scales are close
    # Note: The scales might be stored differently (reciprocal vs. direct), so we need to
    # be careful about how we compare them

    # In triton_quantize_fp8_block, scales are stored as reciprocals (1/scale)
    # In torch_blockwise_scale_act_quant, scales are stored directly
    # So we need to take the reciprocal of one of them for comparison

    # Reshape triton_scale to match ref_scale shape for comparison
    triton_scale_reshaped = triton_scale.reshape(M, -1)

    # Compare reciprocal of triton_scale with ref_scale
    assert torch.allclose(
        1.0 / triton_scale_reshaped, ref_scale, rtol=1e-2, atol=1e-2
    ), (
        f"Scales differ: max diff = {(1.0 / triton_scale_reshaped - ref_scale).abs().max().item()}"
    )


@skip_if_rocm("ROCm not supported")
@pytest.mark.parametrize("tile_size", [128, 256])
@pytest.mark.parametrize("test_eps", [True, False])
def test_triton_quantize_fp8_weight_quant(tile_size: int, test_eps: bool):
    device = "cuda"
    # Make sure dimensions are multiples of tile_size for clean comparison
    M = tile_size * 2
    K = tile_size * 2
    x = torch.randn(M, K, device=device)

    # set one scaling block to 0s, so if nan guards/EPS are not applied, the
    # quantized tensor will have NaNs due to division by 0
    if test_eps:
        x[:tile_size, :tile_size] = 0.0

    # Get the quantized tensor and scales using triton implementation
    triton_fp8, triton_scale = triton_quantize_fp8_block(
        x, block_m=tile_size, block_k=tile_size
    )
    assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs"

    # Get the quantized tensor and scales using reference implementation
    ref_fp8, ref_scale = torch_blockwise_scale_weight_quant(x, tile_size=tile_size)
    assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs"

    # Convert both to float32 for comparison
    triton_fp32 = triton_fp8.to(torch.float32)
    ref_fp32 = ref_fp8.to(torch.float32)

    # Check that the quantized tensors are close
    assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-2, atol=1e-2), (
        f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}"
    )

    # Check that the scales are close
    # In triton_quantize_fp8_block, scales are stored as reciprocals (1/scale)
    # In torch_blockwise_scale_weight_quant, scales are stored directly

    # Compare reciprocal of triton_scale with ref_scale
    assert torch.allclose(1.0 / triton_scale, ref_scale, rtol=1e-2, atol=1e-2), (
        f"Scales differ: max diff = {(1.0 / triton_scale - ref_scale).abs().max().item()}"
    )
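
To make the scale comparison in these tests concrete: triton_quantize_fp8_block reports reciprocal scales while the torch reference reports direct scales, so the assertions compare the reference against the reciprocal of the Triton output. A tiny illustration (the numbers are made up; only the reciprocal relationship matters):

import torch

ref_scale = torch.tensor([0.02, 0.5, 4.0])  # direct scales, as returned by the torch reference
triton_scale = 1.0 / ref_scale              # reciprocal convention, as returned by triton_quantize_fp8_block
assert torch.allclose(1.0 / triton_scale, ref_scale)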
