Commit f28cd30

feat: AutoDeploy fp8 quantization support for bmm (#3849)
Signed-off-by: Wei-Ming Chen <[email protected]>
1 parent 6e48ac2 commit f28cd30

8 files changed (+493, -37 lines)

tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py

Lines changed: 86 additions & 1 deletion

@@ -1,8 +1,10 @@
 """Definition of the quant module that can be used for PTQ."""
 
+import warnings
 from typing import Optional
 
 import torch
+from flashinfer import bmm_fp8
 from torch import nn
 
 from tensorrt_llm._torch.autotuner import autotune
@@ -222,7 +224,90 @@ def fp4_linear_fake(
     return torch.ops.aten.linear(input, weight_fp4.repeat(1, 2).to(input.dtype), bias)
 
 
-QUANT_OPS = [
+def is_column_major(tensor):
+    rows, _ = tensor.shape[-2:]
+    strides = tensor.stride()
+    return strides[-2] == 1 and strides[-1] == rows
+
+
+@torch.library.custom_op("auto_deploy::torch_quant_fp8_bmm", mutates_args=())
+def fp8_bmm(
+    input: torch.Tensor,
+    mat2: torch.Tensor,
+    input_scale: torch.Tensor,
+    weight_scale: torch.Tensor,
+) -> torch.Tensor:
+    """FP8 BMM op similar to torch.bmm.
+
+    Args:
+        input: unquantized input tensor with shape (B, M, K)
+        mat2: weight tensor with shape (B, K, N), with dtype torch.float8_e4m3fn,
+            torch.float16, or torch.bfloat16
+        input_scale: a scalar tensor - the inverse scale for input quantization
+        weight_scale: a scalar tensor - the inverse scale for weight quantization
+
+    Returns:
+        The BMM output with shape (B, M, N) and the same dtype as the input.
+    """
+    # Ensure input is contiguous
+    input = input.contiguous()
+    original_input_dtype = input.dtype
+
+    # Convert input to fp8 using the provided scale
+    if input.dtype in [torch.float16, torch.bfloat16]:
+        input_fp8 = _to_fp8(input, input_scale)
+    else:
+        assert input.dtype == torch.float8_e4m3fn
+        input_fp8 = input
+
+    # Convert mat2 to fp8 using the provided scale
+    if mat2.dtype in [torch.float16, torch.bfloat16]:
+        mat2_fp8 = _to_fp8(mat2, weight_scale)
+    else:
+        assert mat2.dtype == torch.float8_e4m3fn
+        mat2_fp8 = mat2
+
+    # Ensure mat2 is contiguous in column-major format only if needed.
+    # Check if the tensor is already contiguous when transposed (i.e., already column-major).
+    if not is_column_major(mat2_fp8):
+        warnings.warn(
+            "mat2 is not in column-major format; transposing it will cause performance degradation."
+        )
+        mat2_fp8 = mat2_fp8.transpose(-2, -1).contiguous().transpose(-2, -1)
+
+    # Get dimensions
+    B, M, K = input.shape
+    B2, K2, N = mat2_fp8.shape
+    assert B == B2, f"Batch dimensions must match: {B} vs {B2}"
+    assert K == K2, f"Inner dimensions must match: {K} vs {K2}"
+
+    output = torch.empty((B, M, N), dtype=original_input_dtype, device=input.device)
+    bmm_fp8(
+        input_fp8, mat2_fp8, input_scale.float(), weight_scale.float(), original_input_dtype, output
+    )
+
+    return output
+
+
+@fp8_bmm.register_fake
+def fp8_bmm_fake(
+    input: torch.Tensor,
+    mat2: torch.Tensor,
+    input_scale: torch.Tensor,
+    weight_scale: torch.Tensor,
+) -> torch.Tensor:
+    """Fake implementation of fp8_bmm for testing and tracing."""
+    # Use standard bmm
+    return torch.bmm(input.to(torch.float), mat2.to(torch.float)).to(input.dtype)
+
+
+QUANT_LINEAR_OPS = [
     torch.ops.auto_deploy.torch_quant_fp8_linear,
     torch.ops.auto_deploy.torch_quant_fp4_linear,
 ]
+
+QUANT_BMM_OPS = [
+    torch.ops.auto_deploy.torch_quant_fp8_bmm,
+]
+
+QUANT_OPS = QUANT_LINEAR_OPS + QUANT_BMM_OPS
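The new op can be exercised directly once the module above has been imported (registration happens as a side effect of the import). Below is a minimal sketch, assuming flashinfer and an FP8-capable GPU are available; the amax/448 per-tensor scale shown here is only an assumption for illustration, since the actual scale convention comes from _to_fp8 and the checkpoint, neither of which is part of this diff.

import torch
import tensorrt_llm._torch.auto_deploy.custom_ops.quant  # noqa: F401  (registers the custom op; needs flashinfer)

# Illustrative shapes only.
B, M, K, N = 4, 32, 64, 128
x = torch.randn(B, M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randn(B, K, N, dtype=torch.bfloat16, device="cuda")

# Assumed per-tensor scale convention (amax / 448 for e4m3); the real
# convention is set by _to_fp8 and the checkpoint, not shown in this diff.
x_scale = (x.abs().amax() / 448.0).float()
w_scale = (w.abs().amax() / 448.0).float()

# Eager execution goes through flashinfer's bmm_fp8; under tracing the
# registered fake falls back to a plain float32 torch.bmm.
out = torch.ops.auto_deploy.torch_quant_fp8_bmm(x, w, x_scale, w_scale)
ref = torch.bmm(x.float(), w.float()).to(x.dtype)
print(out.shape, (out.float() - ref.float()).abs().max())

Because w here is row-major, the op will emit the column-major warning and re-lay out the weight internally, exactly as the is_column_major branch above describes.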

tensorrt_llm/_torch/auto_deploy/transformations/library/fusion.py

Lines changed: 13 additions & 7 deletions

@@ -41,7 +41,7 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node
     sizes_unfused = [p.size(0) for p in params_unfused]
     key_fused = f"fused_weight_{idx}"
 
-    quantization_impl = QuantizationImpl.create(linear_nodes[0])
+    quantization_impls = [QuantizationImpl.create(n) for n in linear_nodes]
 
     def fuse_weights(tensors: List[torch.Tensor]) -> torch.Tensor:
         """Fuse weights of linear nodes."""
@@ -51,17 +51,20 @@ def split_output(tensor: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """Split the output tensor of the fused linear node to obtain the original outputs."""
         return tuple(t.contiguous() for t in torch.split(tensor, sizes_unfused, dim=-1))
 
-    if quantization_impl:
+    if all(
+        q is not None and quantization_impls[0].target_op() == q.target_op()
+        for q in quantization_impls
+    ):
         scales = {}
         for weight_key in keys_unfused:
             key = weight_key.rsplit(".", 1)[0]
 
-            for scale_name in quantization_impl.scale_names():
+            for scale_name in quantization_impls[0].scale_names():
                 buffer_name = key + "." + scale_name
                 scales.setdefault(scale_name, []).append(gm.get_buffer(buffer_name))
 
         try:
-            weights_fused, buffer_fused = quantization_impl.fuse_linear_weights(
+            weights_fused, buffer_fused = quantization_impls[0].fuse_linear_weights(
                 params_unfused, **scales
             )
         except NotImplementedError as e:
@@ -73,8 +76,11 @@ def split_output(tensor: torch.Tensor) -> Tuple[torch.Tensor, ...]:
             fused_buffer_name = key_fused + "_" + scale_name
             gm.register_buffer(fused_buffer_name, buffer)
 
-    else:
+    elif all(q is None for q in quantization_impls):
         param_fused = nn.Parameter(fuse_weights([gm.get_parameter(k) for k in keys_unfused]))
+    else:
+        ad_logger.warning(f"Cannot fuse ops {keys_unfused} for mixed-precision linear nodes.")
+        return
 
     setattr(gm, key_fused, param_fused)
 
@@ -83,8 +89,8 @@ def split_output(tensor: torch.Tensor) -> Tuple[torch.Tensor, ...]:
 
     with gm.graph.inserting_before(linear_nodes[0]):
         get_param_node = gm.graph.get_attr(key_fused, torch.Tensor)
-        if quantization_impl:
-            for scale_name in quantization_impl.scale_names():
+        if quantization_impls[0]:
+            for scale_name in quantization_impls[0].scale_names():
                # Creates new nodes for the fused scales so the unfused linear ops can be fully erased.
                fused_kwargs[scale_name] = gm.graph.create_node(
                    "get_attr", key_fused + "_" + scale_name

tensorrt_llm/_torch/auto_deploy/transformations/library/quantization.py

Lines changed: 116 additions & 12 deletions

@@ -9,6 +9,7 @@
 from ...utils.node_utils import (
     extract_param_names_from_lin_node,
     get_quantization_params_from_linear_node,
+    is_bmm_op,
     is_linear_op,
     is_match,
 )
@@ -81,8 +82,95 @@ def _insert_quantized_linear(
     node.kwargs = {**node.kwargs, **scales}
 
 
+def _insert_quantized_bmm(
+    gm: GraphModule,
+    node: Node,
+    quantization_impl: QuantizationImpl,
+    is_quantized_graph: bool = False,
+):
+    """Replaces the bmm node with a new quantized bmm node."""
+    weight_node = node.args[1]
+
+    # Weight is a parameter
+    if weight_node.op == "get_attr":
+        # Handle parameter tensor
+        param_name = weight_node.target
+        original_weight = gm.get_parameter(param_name)
+        weight_shape = original_weight.shape
+
+        # Quantize the weight
+        new_param = nn.Parameter(
+            quantization_impl.quantize_weight(original_weight), requires_grad=False
+        )
+
+        # Update the parameter in the model
+        modname, _, attrname = param_name.rpartition(".")
+        submod = gm.get_submodule(modname)
+        setattr(submod, attrname, new_param)
+
+        # Register load state dict hook
+        gm._register_load_state_dict_pre_hook(
+            partial(quantization_impl.load_hook, weight_name=param_name)
+        )
+        if quantization_impl.post_load_hook:
+            gm.register_load_state_dict_post_hook(
+                partial(quantization_impl.post_load_hook, weight_name=param_name)
+            )
+
+        # Set up scale names and target module for the parameter case
+        def get_scale_name(scale_name):
+            return attrname + "_" + scale_name
+
+        scale_target_module = submod
+        scale_name_prefix = f"{modname}."
+
+    # Weight is a dynamic tensor
+    elif hasattr(weight_node, "meta") and "val" in weight_node.meta:
+        weight_shape = weight_node.meta["val"].shape
+
+        # Create a unique identifier for this dynamic weight node
+        node_id = f"bmm_dynamic_{id(node)}"
+
+        # Set up scale names and target module for the dynamic case
+        def get_scale_name(scale_name):
+            return f"{node_id}_{scale_name}"
+
+        scale_target_module = gm  # Register in the root module
+        scale_name_prefix = ""
+
+        ad_logger.info(f"Quantized BMM with dynamic weight tensor for node {node}")
+    else:
+        # If we can't determine the shape, skip quantization
+        ad_logger.warning(
+            f"BMM weight is a dynamic tensor without shape metadata, skipping quantization for node {node}"
+        )
+        return
+
+    # Common logic for both the parameter and dynamic tensor cases:
+    # register scales in the target module
+    for scale_name, scale in quantization_impl.default_scales(weight_shape).items():
+        scale_buffer_name = get_scale_name(scale_name)
+        scale_target_module.register_buffer(scale_buffer_name, scale)
+
+    # Change the node target to the quantized bmm op
+    node.target = quantization_impl.target_op()
+
+    # Insert scale nodes
+    with gm.graph.inserting_before(node):
+        scales = {}
+        for scale_name in quantization_impl.scale_names():
+            scale_buffer_name = get_scale_name(scale_name)
+            scales[scale_name] = gm.graph.create_node(
+                "get_attr", f"{scale_name_prefix}{scale_buffer_name}"
+            )
+
+    # Update node arguments and kwargs
+    scale_values = [scales[scale_name] for scale_name in quantization_impl.scale_names()]
+    node.args = (*node.args, *scale_values)
+
+
 def quantize(gm: GraphModule, quant_config: Dict[str, Any]):
-    """Quantize the GraphModule and replace linear with quantized linear."""
+    """Quantize the GraphModule and replace linear and bmm with quantized versions."""
     # extract info from quant_config
     is_quant_graph = is_quantized_graph(gm)
     quant_algo = quant_config.get("quant_algo")
@@ -93,28 +181,44 @@ def quantize(gm: GraphModule, quant_config: Dict[str, Any]):
         ad_logger.info("No quantization to do.")
         return gm
 
-    # tracking quantized linears in the graph
-    quantized_nodes: Dict[str, int] = defaultdict(lambda: 0)
+    # tracking quantized operations in the graph
+    quantized_nodes: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
     for n in gm.graph.nodes:
         # check if we should skip this node
-        if is_match(n, skip) or not is_linear_op(n, include_quantization=False):
+        if is_match(n, skip):
             continue
 
-        # get per-layer quantization format from the node
-        quant_algo_n: str = get_quantization_from_linear_node(n) if is_quant_graph else quant_algo
-        if not quant_algo_n:
-            continue
+        # Process linear operations
+        if is_linear_op(n, include_quantization=False):
+            # get per-layer quantization format from the node
+            quant_algo_n: str = (
+                get_quantization_from_linear_node(n) if is_quant_graph else quant_algo
+            )
+            if not quant_algo_n:
+                continue
+
+            # insert quantized linear node
+            _insert_quantized_linear(gm, n, QuantizationImpl.create(quant_algo_n), is_quant_graph)
+            quantized_nodes[quant_algo_n]["linear"] += 1
 
-        # insert quantized linear node
-        _insert_quantized_linear(gm, n, QuantizationImpl.create(quant_algo_n), is_quant_graph)
-        quantized_nodes[quant_algo_n] += 1
+        # Process BMM operations
+        elif is_bmm_op(n):
+            if not quant_algo:
+                continue
+
+            # insert quantized bmm node
+            _insert_quantized_bmm(
+                gm, n, QuantizationImpl.create(quant_algo, is_bmm=True), is_quant_graph
+            )
+            quantized_nodes[quant_algo]["bmm"] += 1
 
     if is_quant_graph:
         remove_output_quantizers(gm)
 
     gm = canonicalize_graph(gm)
     for quant_algo in quantized_nodes:
-        ad_logger.info(f"Found {quantized_nodes[quant_algo]} {quant_algo} quantized nodes.")
+        for op_type, count in quantized_nodes[quant_algo].items():
+            ad_logger.info(f"Found {count} {quant_algo} quantized {op_type} nodes.")
     ad_logger.debug("After quantization: " + str(gm))
 
     return gm
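To make the graph surgery concrete, the following self-contained torch.fx sketch goes through the same mechanics that _insert_quantized_bmm relies on: locate a bmm call, register per-tensor scale buffers on the owning module, pull them into the graph as get_attr nodes, and append them to the call's arguments. The buffer names and the fake_quant_bmm stand-in are invented for illustration; the real pass retargets the node to the quantized custom op returned by quantization_impl.target_op().

import torch
import torch.fx as fx

class TinyBmm(torch.nn.Module):
    def forward(self, x, w):
        return torch.bmm(x, w)

gm = fx.symbolic_trace(TinyBmm())

# Locate the bmm call, as the quantize() loop does via is_bmm_op.
bmm_node = next(
    n for n in gm.graph.nodes if n.op == "call_function" and n.target is torch.bmm
)

# Register per-tensor scale buffers on the root module; the names are
# invented here (the real values come from QuantizationImpl.default_scales()).
gm.register_buffer("bmm_input_scale", torch.tensor(1.0))
gm.register_buffer("bmm_weight_scale", torch.tensor(1.0))

# Pull the scales into the graph as get_attr nodes and append them to the
# call's arguments, mirroring the tail of _insert_quantized_bmm.
with gm.graph.inserting_before(bmm_node):
    in_scale = gm.graph.get_attr("bmm_input_scale")
    w_scale = gm.graph.get_attr("bmm_weight_scale")
bmm_node.args = (*bmm_node.args, in_scale, w_scale)

# The real pass would retarget to the quantized custom op; this stand-in has
# the same arity so the toy graph still runs without flashinfer.
def fake_quant_bmm(a, b, a_scale, b_scale):
    return torch.bmm(a, b)  # scales ignored in this illustration

bmm_node.target = fake_quant_bmm
gm.recompile()

x, w = torch.randn(2, 3, 4), torch.randn(2, 4, 5)
print(gm(x, w).shape)  # torch.Size([2, 3, 5])
print(gm.graph)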

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

Lines changed: 12 additions & 2 deletions

@@ -8,7 +8,7 @@
 from torch._ops import OpOverload, OpOverloadPacket
 from torch.fx import Graph, GraphModule, Node
 
-from ..custom_ops.quant import QUANT_OPS
+from ..custom_ops.quant import QUANT_BMM_OPS, QUANT_LINEAR_OPS
 from .logger import ad_logger
 
 try:
@@ -226,10 +226,20 @@ def is_linear_op(node: Node, include_quantization: bool = False) -> bool:
     }
 
     if include_quantization:
-        lin_ops.update(QUANT_OPS)
+        lin_ops.update(QUANT_LINEAR_OPS)
     return is_op(node, lin_ops)
 
 
+def is_bmm_op(node: Node, include_quantization: bool = False) -> bool:
+    """Check if the node is a batched matmul (bmm) op."""
+    bmm_ops = {torch.ops.aten.bmm}
+
+    if include_quantization:
+        bmm_ops.update(QUANT_BMM_OPS)
+
+    return is_op(node, bmm_ops)
+
+
 def is_dist_op(node: Node) -> bool:
     """Check if the node is a distributed op."""
     dist_ops = {
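As a side note, is_bmm_op follows the same is_op matching pattern as is_linear_op and is_dist_op. The snippet below is a local stand-in, not the library helper, showing how such a check picks bmm calls out of a traced graph; the real helper additionally folds in QUANT_BMM_OPS when include_quantization=True.

import torch
import torch.fx as fx

def _is_bmm_node(node: fx.Node) -> bool:
    """Local stand-in for is_bmm_op(); checks the call target against common bmm variants."""
    bmm_targets = {torch.bmm, torch.ops.aten.bmm, torch.ops.aten.bmm.default}
    return node.op == "call_function" and node.target in bmm_targets

class Net(torch.nn.Module):
    def forward(self, x, w):
        return torch.bmm(x, w) + 1

gm = fx.symbolic_trace(Net())
print([n.format_node() for n in gm.graph.nodes if _is_bmm_node(n)])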
