99from ...utils .node_utils import (
1010 extract_param_names_from_lin_node ,
1111 get_quantization_params_from_linear_node ,
12+ is_bmm_op ,
1213 is_linear_op ,
1314 is_match ,
1415)
@@ -81,8 +82,95 @@ def _insert_quantized_linear(
8182 node .kwargs = {** node .kwargs , ** scales }
8283
8384
def _insert_quantized_bmm(
    gm: GraphModule,
    node: Node,
    quantization_impl: QuantizationImpl,
    is_quantized_graph: bool = False,
):
    """Replace a BMM node in-place with its quantized equivalent.

    Supports two weight layouts:
      * ``get_attr`` weights (a module parameter): the parameter is quantized
        in place, state-dict load hooks are registered, and the quantization
        scales are registered as buffers on the owning submodule.
      * dynamic tensor weights (weight produced by another node): scales are
        registered on the root module under a per-node unique name.

    If the weight's shape cannot be determined, the node is left untouched.

    Args:
        gm: The graph module that owns ``node``.
        node: The BMM node to quantize; ``node.args[1]`` is the weight.
        quantization_impl: Backend providing quantize/scale/target-op logic.
        is_quantized_graph: Unused here; kept for signature parity with
            ``_insert_quantized_linear``.
    """
    weight_node = node.args[1]

    if weight_node.op == "get_attr":
        # Weight is a parameter stored on the module hierarchy.
        param_name = weight_node.target
        original_weight = gm.get_parameter(param_name)
        weight_shape = original_weight.shape

        # Quantize the weight and swap it into the owning submodule.
        new_param = nn.Parameter(
            quantization_impl.quantize_weight(original_weight), requires_grad=False
        )
        modname, _, attrname = param_name.rpartition(".")
        submod = gm.get_submodule(modname)
        setattr(submod, attrname, new_param)

        # Hooks so that checkpoints with unquantized weights still load.
        gm._register_load_state_dict_pre_hook(
            partial(quantization_impl.load_hook, weight_name=param_name)
        )
        if quantization_impl.post_load_hook:
            gm.register_load_state_dict_post_hook(
                partial(quantization_impl.post_load_hook, weight_name=param_name)
            )

        # Scales live next to the weight on its submodule.
        def get_scale_name(scale_name):
            return attrname + "_" + scale_name

        scale_target_module = submod
        # BUGFIX: for a top-level parameter (no "." in param_name) modname is
        # "" and an unconditional f"{modname}." prefix would produce a
        # leading-dot get_attr target like ".weight_scale", which cannot be
        # resolved. Only prepend the dot when there is a module path.
        scale_name_prefix = f"{modname}." if modname else ""

    elif hasattr(weight_node, "meta") and "val" in weight_node.meta:
        # Weight is a dynamic tensor; shape comes from FX shape metadata.
        weight_shape = weight_node.meta["val"].shape

        # Unique identifier for this dynamic weight node.
        # NOTE(review): id(node) is not stable across processes/runs, so these
        # buffer names will not round-trip through a state dict — confirm this
        # is acceptable for dynamic-weight scales.
        node_id = f"bmm_dynamic_{id(node)}"

        # Scales live on the root module under the per-node unique name.
        def get_scale_name(scale_name):
            return f"{node_id}_{scale_name}"

        scale_target_module = gm
        scale_name_prefix = ""

        ad_logger.info(f"Quantized BMM with dynamic weight tensor for node {node}")
    else:
        # Without shape metadata we cannot build scales; leave node unchanged.
        ad_logger.warning(
            f"BMM weight is dynamic tensor without shape metadata, skipping quantization for node {node}"
        )
        return

    # Common path: register default scale buffers on the chosen module.
    for scale_name, scale in quantization_impl.default_scales(weight_shape).items():
        scale_target_module.register_buffer(get_scale_name(scale_name), scale)

    # Retarget the node to the quantized BMM op.
    node.target = quantization_impl.target_op()

    # Materialize get_attr nodes for each scale and append them as args.
    with gm.graph.inserting_before(node):
        scales = {}
        for scale_name in quantization_impl.scale_names():
            scale_buffer_name = get_scale_name(scale_name)
            scales[scale_name] = gm.graph.create_node(
                "get_attr", f"{scale_name_prefix}{scale_buffer_name}"
            )

    scale_values = [scales[scale_name] for scale_name in quantization_impl.scale_names()]
    node.args = (*node.args, *scale_values)
171+
84172def quantize (gm : GraphModule , quant_config : Dict [str , Any ]):
85- """Quantize the GraphModule and replace linear with quantized linear ."""
173+ """Quantize the GraphModule and replace linear and bmm with quantized versions ."""
86174 # extract info from quant_config
87175 is_quant_graph = is_quantized_graph (gm )
88176 quant_algo = quant_config .get ("quant_algo" )
@@ -93,28 +181,44 @@ def quantize(gm: GraphModule, quant_config: Dict[str, Any]):
93181 ad_logger .info ("No quantization to do." )
94182 return gm
95183
96- # tracking quantized linears in the graph
97- quantized_nodes : Dict [str , int ] = defaultdict (lambda : 0 )
184+ # tracking quantized operations in the graph
185+ quantized_nodes : Dict [str , Dict [ str , int ]] = defaultdict (lambda : defaultdict ( int ) )
98186 for n in gm .graph .nodes :
99187 # check if we should skip this node
100- if is_match (n , skip ) or not is_linear_op ( n , include_quantization = False ) :
188+ if is_match (n , skip ):
101189 continue
102190
103- # get per-layer quantization format from the node
104- quant_algo_n : str = get_quantization_from_linear_node (n ) if is_quant_graph else quant_algo
105- if not quant_algo_n :
106- continue
191+ # Process linear operations
192+ if is_linear_op (n , include_quantization = False ):
193+ # get per-layer quantization format from the node
194+ quant_algo_n : str = (
195+ get_quantization_from_linear_node (n ) if is_quant_graph else quant_algo
196+ )
197+ if not quant_algo_n :
198+ continue
199+
200+ # insert quantized linear node
201+ _insert_quantized_linear (gm , n , QuantizationImpl .create (quant_algo_n ), is_quant_graph )
202+ quantized_nodes [quant_algo_n ]["linear" ] += 1
107203
108- # insert quantized linear node
109- _insert_quantized_linear (gm , n , QuantizationImpl .create (quant_algo_n ), is_quant_graph )
110- quantized_nodes [quant_algo_n ] += 1
204+ # Process BMM operations
205+ elif is_bmm_op (n ):
206+ if not quant_algo :
207+ continue
208+
209+ # insert quantized bmm node
210+ _insert_quantized_bmm (
211+ gm , n , QuantizationImpl .create (quant_algo , is_bmm = True ), is_quant_graph
212+ )
213+ quantized_nodes [quant_algo ]["bmm" ] += 1
111214
112215 if is_quant_graph :
113216 remove_output_quantizers (gm )
114217
115218 gm = canonicalize_graph (gm )
116219 for quant_algo in quantized_nodes :
117- ad_logger .info (f"Found { quantized_nodes [quant_algo ]} { quant_algo } quantized nodes." )
220+ for op_type , count in quantized_nodes [quant_algo ].items ():
221+ ad_logger .info (f"Found { count } { quant_algo } quantized { op_type } nodes." )
118222 ad_logger .debug ("After quantization: " + str (gm ))
119223
120224 return gm
0 commit comments