14
14
15
15
import torch
16
16
import torch .nn as nn
17
- import types
18
17
19
18
from .quant_config import QuantMode , get_hqt_config , ScaleFormat
20
19
from .._core .quant_dequant import QuantDequant as qdq
@@ -75,10 +74,13 @@ def set_attrs_from_orig_model(cls_instance, mod, mod_extra_config, *func_names):
75
74
cls_instance .fake_quant = config .cfg ["fake_quant" ]
76
75
cls_instance .use_qdq = config .cfg ["use_qdq" ]
77
76
cls_instance .scale_format = config .cfg ["scale_format" ]
78
- cls_instance .forward_orig = types .MethodType (mod .forward .__func__ , cls_instance )
77
+ # store original module in order to invoke its functions during measurements.
78
+ # this may be omitted of torch remove the related validation from dynamo. see SW-187731.
79
+ cls_instance .__dict__ ["orig_mod" ] = mod
80
+ cls_instance .forward_orig = mod .forward
79
81
if func_names is not None :
80
82
for func in func_names :
81
- setattr (cls_instance , func , types . MethodType ( getattr (mod , func ). __func__ , cls_instance ))
83
+ setattr (cls_instance , func , getattr (mod , func ))
82
84
83
85
84
86
def get_current_repr (cls_instance , * member_names ):
@@ -143,7 +145,7 @@ def forward_qdq(self, input, other):
143
145
144
146
def forward_measure (self , input , other ):
145
147
measure_input ((input , other ), observer = self ._mod_extra_config .inputs )
146
- output = self .forward_orig (input , other )
148
+ output = self .orig_mod (input , other )
147
149
measure_output ((output ,), self ._mod_extra_config .outputs )
148
150
return output
149
151
@@ -212,7 +214,7 @@ def forward_quant(self, input):
212
214
213
215
def forward_measure (self , input ):
214
216
measure_input ((input ,), observer = self ._mod_extra_config .inputs )
215
- output = self .forward_orig (input )
217
+ output = self .orig_mod (input )
216
218
measure_output ((output ,), self ._mod_extra_config .outputs )
217
219
return output
218
220
@@ -520,7 +522,7 @@ def forward_qdq(self, input):
520
522
output = torch .matmul (qinput , qweight )
521
523
522
524
if self .gather_output :
523
- output = self .collective_func (output )
525
+ output = self .orig_mod . collective_func (output )
524
526
return self .post_all_reduce (output )
525
527
526
528
def forward_quant (self , input ):
@@ -532,15 +534,15 @@ def forward_quant(self, input):
532
534
scale_other_inv = self .scale_weight )
533
535
dqoutput = self .dequant_output (output )
534
536
if self .gather_output :
535
- dqoutput = self .collective_func (dqoutput )
537
+ dqoutput = self .orig_mod . collective_func (dqoutput )
536
538
return self .post_all_reduce (dqoutput )
537
539
538
540
def forward_measure (self , input ):
539
541
measure_input ((input ,), observer = self ._mod_extra_config .inputs )
540
542
output = torch .matmul (input , self .weight .transpose (- 1 , - 2 ))
541
543
measure_output ((output ,), self ._mod_extra_config .outputs )
542
544
if self .gather_output :
543
- output = self .collective_func (output )
545
+ output = self .orig_mod . collective_func (output )
544
546
return self .post_all_reduce (output )
545
547
546
548
def post_all_reduce (self , output ):
@@ -689,7 +691,7 @@ def __init__(self, mod, mod_extra_config, *args, **kwargs):
689
691
super ().__init__ ()
690
692
set_attrs_from_orig_model (self , mod , mod_extra_config )
691
693
if self .quantization_mode in [QuantMode .QUANTIZE , QuantMode .LOAD ]:
692
- self .orig_fetch_from_cache = types . MethodType ( mod .fetch_from_cache . __func__ , self )
694
+ self .orig_fetch_from_cache = mod .fetch_from_cache
693
695
self .quant_input = self ._mod_extra_config .inputs [0 ]
694
696
self .dequant_output = self ._mod_extra_config .outputs [0 ]
695
697
if self .use_qdq :
@@ -775,7 +777,7 @@ def forward_quant(self, input):
775
777
776
778
def forward_measure (self , input ):
777
779
measure_input ((input ,), observer = self ._mod_extra_config .inputs )
778
- output = self .forward_orig (input )
780
+ output = self .orig_mod (input )
779
781
measure_output ((output ,), self ._mod_extra_config .outputs )
780
782
return output
781
783
@@ -816,7 +818,7 @@ def forward_quant(self, x, dim=None, invAttnHead=None):
816
818
817
819
def forward_measure (self , x , dim = None , invAttnHead = None ):
818
820
measure_input ((x ,), observer = self ._mod_extra_config .inputs )
819
- output = self .forward_orig (x , dim )
821
+ output = self .orig_mod (x , dim )
820
822
measure_output ((output ,), self ._mod_extra_config .outputs )
821
823
return output
822
824
@@ -861,7 +863,7 @@ def forward_quant(self, input, scale: float = 1.0):
861
863
862
864
def forward_measure (self , input , scale : float = 1.0 ):
863
865
measure_input ((input ,), observer = self ._mod_extra_config .inputs )
864
- output = self .forward_orig (input , scale )
866
+ output = self .orig_mod (input , scale )
865
867
measure_output ((output ,), self ._mod_extra_config .outputs )
866
868
return output
867
869
@@ -916,7 +918,7 @@ def forward_quant(self, input, scale: float = 1.0):
916
918
917
919
def forward_measure (self , input , scale : float = 1.0 ):
918
920
measure_input ((input ,), observer = self ._mod_extra_config .inputs )
919
- output = self .forward_orig (input , scale )
921
+ output = self .orig_mod (input , scale )
920
922
measure_output ((output ,), self ._mod_extra_config .outputs )
921
923
return output
922
924
0 commit comments