 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     check_out_zero_point_is_min_range,
+    copy_node_metadata,
     create_zero_bias_int32,
     find_sequential_partitions_aten,
     get_conv_args,
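The newly imported copy_node_metadata is a helper from executorch.backends.cadence.aot.quantizer.utils whose body is not shown in this diff. As a rough, hypothetical illustration of what such a helper typically does (the real implementation may copy a different set of keys), it forwards provenance metadata from a source node to a newly created node:

import torch.fx as fx

def copy_node_metadata_sketch(dst: fx.Node, src: fx.Node) -> None:
    # Hypothetical stand-in, not the ExecuTorch helper: forward debugging /
    # provenance keys commonly kept on fx nodes so synthesized nodes remain
    # traceable to the original program.
    for key in ("stack_trace", "nn_module_stack", "source_fn_stack"):
        if key in src.meta and key not in dst.meta:
            dst.meta[key] = src.meta[key]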
@@ -159,6 +160,20 @@ def get_args_and_kwargs_layer_norm(
             ),
             {"dtype": torch.float32},
         )
+        if len(inputs_inputs) > 0:
+            if "val" in inputs_inputs[0].meta:
+                fake_mode = inputs_inputs[0].meta["val"].fake_mode
+                if fake_mode is not None:
+                    with fake_mode:
+                        fake_weight = torch.full(
+                            other_inputs[0], 1, dtype=torch.float32
+                        )
+                    weight.meta["val"] = fake_weight
+                else:
+                    weight.meta["val"] = torch.full(
+                        other_inputs[0], 1, dtype=torch.float32
+                    )
+                copy_node_metadata(weight, inputs_inputs[0])
 
     bias = other_inputs[2] if len(other_inputs) > 2 else None
 
@@ -171,6 +186,18 @@ def get_args_and_kwargs_layer_norm(
             ),
             {"dtype": torch.float32},
         )
+        if len(inputs_inputs) > 0:
+            if "val" in inputs_inputs[0].meta:
+                fake_mode = inputs_inputs[0].meta["val"].fake_mode
+                if fake_mode is not None:
+                    with fake_mode:
+                        fake_bias = torch.full(other_inputs[0], 0, dtype=torch.float32)
+                    bias.meta["val"] = fake_bias
+                else:
+                    bias.meta["val"] = torch.full(
+                        other_inputs[0], 0, dtype=torch.float32
+                    )
+                copy_node_metadata(bias, inputs_inputs[0])
 
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs + [scale, zero_point])
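The two hunks above (and the softmax hunk below) follow the same pattern: when the quantized input node already carries a FakeTensor in meta["val"], the synthesized constant is materialized inside that node's FakeTensorMode so downstream shape/dtype propagation sees consistent metadata. A minimal standalone sketch of that pattern, with illustrative names and shapes, not the ExecuTorch code itself:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()   # stands in for inputs_inputs[0].meta["val"].fake_mode
normalized_shape = [8]         # stands in for other_inputs[0]

with fake_mode:
    # Created under the fake mode: carries shape/dtype but allocates no real storage.
    fake_weight = torch.full(normalized_shape, 1, dtype=torch.float32)

print(type(fake_weight).__name__, tuple(fake_weight.shape))  # FakeTensor (8,)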
@@ -346,6 +373,16 @@ def get_args_and_kwargs_softmax(
         ),
         {"dtype": torch.int32},
     )
+    if len(inputs_inputs) > 0:
+        if "val" in inputs_inputs[0].meta:
+            fake_mode = inputs_inputs[0].meta["val"].fake_mode
+            if fake_mode is not None:
+                with fake_mode:
+                    fake_mask = torch.full(mask_shape, 0.0, dtype=torch.int32)
+                mask_tensor.meta["val"] = fake_mask
+            else:
+                mask_tensor.meta["val"] = torch.full(mask_shape, 0.0, dtype=torch.int32)
+            copy_node_metadata(mask_tensor, inputs_inputs[0])
     # Make the scale and zero_point tensors
     in_scale = dequants_inputs[0].args[1]
     in_zero_point = dequants_inputs[0].args[2]
@@ -395,10 +432,13 @@ def get_args_and_kwargs_mixed_w8a32_conv(
         torch.ops.aten.permute.default,
         (other_inputs[0], [0, 2, 1]),  # NCL -> NLC
     )
+    copy_node_metadata(transposed_inputs, other_inputs[0])
+
     transposed_weights = graph_module.graph.call_function(
         torch.ops.aten.permute.default,
         (weights_inputs[0], [2, 0, 1]),  # NCL -> LNC
     )
+    copy_node_metadata(transposed_weights, weights_inputs[0])
 
     args = (
         transposed_inputs,
@@ -582,6 +622,26 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
                         torch.ops.aten.transpose.int,
                         (weights_inputs[0], 0, 1),
                     )
+                    if "val" in weights_inputs[0].meta:
+                        original_val = weights_inputs[0].meta["val"]
+                        fake_mode = original_val.fake_mode
+                        if fake_mode is not None:
+                            with fake_mode:
+                                transposed_val = torch.ops.aten.transpose.int(
+                                    original_val, 0, 1
+                                )
+                            transposed_weights.meta["val"] = transposed_val
+                        else:
+                            transposed_shape = list(original_val.shape)
+                            transposed_shape[0], transposed_shape[1] = (
+                                transposed_shape[1],
+                                transposed_shape[0],
+                            )
+                            transposed_weights.meta["val"] = torch.zeros(
+                                transposed_shape, dtype=original_val.dtype
+                            )
+                        copy_node_metadata(transposed_weights, weights_inputs[0])
+
                     # Call linear with transposed weight
                     args, kwargs = get_args_and_kwargs_linear(
                         graph_module,
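For the transposed linear weight, the hunk above reuses the existing fake value: rerunning the same aten op on it reproduces the output shape/dtype the real op would give, which is then stored on the new node's meta["val"]. A small sketch of that behavior, assuming an illustrative 16x32 weight:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
with fake_mode:
    original_val = torch.zeros(16, 32, dtype=torch.float32)  # stands in for weights_inputs[0].meta["val"]
    transposed_val = torch.ops.aten.transpose.int(original_val, 0, 1)

print(tuple(transposed_val.shape))  # (32, 16), with no real data allocated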
@@ -654,6 +714,19 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
 
             legalize_graph(graph_module)
             graph_module.graph.eliminate_dead_code()
+            nodes_list = list(graph_module.graph.nodes)
+
+            if len(nodes_list) > 0 and nodes_list[-1].op != "output":
+                output_nodes = [n for n in nodes_list if n.op == "output"]
+                output_arg = output_nodes[0].args[0]
+                original_meta = output_nodes[0].meta.copy()
+
+                for out_node in output_nodes:
+                    graph_module.graph.erase_node(out_node)
+
+                new_output_node = graph_module.graph.output(output_arg)
+                new_output_node.meta.update(original_meta)
+
         graph_module.recompile()
         return PassResult(graph_module, True)
 
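The last hunk guards against a graph whose node list no longer ends with its output node after legalization and dead-code elimination: it erases the existing output node(s) and re-emits a single output at the end of the graph, preserving the original metadata. The mechanics on a toy fx graph (not the Cadence pass itself):

import torch
import torch.fx as fx

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1

gm = fx.symbolic_trace(AddOne())

# Re-emit the output node so it is the last node in the graph, keeping its meta.
output_nodes = [n for n in gm.graph.nodes if n.op == "output"]
output_arg = output_nodes[0].args[0]
original_meta = output_nodes[0].meta.copy()
for out_node in output_nodes:
    gm.graph.erase_node(out_node)
new_output_node = gm.graph.output(output_arg)
new_output_node.meta.update(original_meta)

gm.recompile()
print(gm(torch.zeros(2)))  # tensor([1., 1.])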