Skip to content

Commit 795b716

Browse files
Revert "[SW-205334][SW-187731] llama70b vLLM fix graph breaks with torch.com…" (#87)
This reverts commit 01a5734. Co-authored-by: Danny Semiat <[email protected]>
1 parent 75fd28f commit 795b716

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import torch
1616
import torch.nn as nn
17-
import types
1817

1918
from .quant_config import QuantMode, get_hqt_config, ScaleFormat
2019
from .._core.quant_dequant import QuantDequant as qdq
@@ -75,10 +74,13 @@ def set_attrs_from_orig_model(cls_instance, mod, mod_extra_config, *func_names):
7574
cls_instance.fake_quant = config.cfg["fake_quant"]
7675
cls_instance.use_qdq = config.cfg["use_qdq"]
7776
cls_instance.scale_format = config.cfg["scale_format"]
78-
cls_instance.forward_orig = types.MethodType(mod.forward.__func__, cls_instance)
77+
# store original module in order to invoke its functions during measurements.
78+
# this may be omitted if torch removes the related validation from dynamo. see SW-187731.
79+
cls_instance.__dict__["orig_mod"] = mod
80+
cls_instance.forward_orig = mod.forward
7981
if func_names is not None:
8082
for func in func_names:
81-
setattr(cls_instance, func, types.MethodType(getattr(mod, func).__func__, cls_instance))
83+
setattr(cls_instance, func, getattr(mod, func))
8284

8385

8486
def get_current_repr(cls_instance, *member_names):
@@ -143,7 +145,7 @@ def forward_qdq(self, input, other):
143145

144146
def forward_measure(self, input, other):
145147
measure_input((input, other), observer=self._mod_extra_config.inputs)
146-
output = self.forward_orig(input, other)
148+
output = self.orig_mod(input, other)
147149
measure_output((output,), self._mod_extra_config.outputs)
148150
return output
149151

@@ -212,7 +214,7 @@ def forward_quant(self, input):
212214

213215
def forward_measure(self, input):
214216
measure_input((input,), observer=self._mod_extra_config.inputs)
215-
output = self.forward_orig(input)
217+
output = self.orig_mod(input)
216218
measure_output((output,), self._mod_extra_config.outputs)
217219
return output
218220

@@ -520,7 +522,7 @@ def forward_qdq(self, input):
520522
output = torch.matmul(qinput, qweight)
521523

522524
if self.gather_output:
523-
output = self.collective_func(output)
525+
output = self.orig_mod.collective_func(output)
524526
return self.post_all_reduce(output)
525527

526528
def forward_quant(self, input):
@@ -532,15 +534,15 @@ def forward_quant(self, input):
532534
scale_other_inv=self.scale_weight)
533535
dqoutput = self.dequant_output(output)
534536
if self.gather_output:
535-
dqoutput = self.collective_func(dqoutput)
537+
dqoutput = self.orig_mod.collective_func(dqoutput)
536538
return self.post_all_reduce(dqoutput)
537539

538540
def forward_measure(self, input):
539541
measure_input((input,), observer=self._mod_extra_config.inputs)
540542
output = torch.matmul(input, self.weight.transpose(-1, -2))
541543
measure_output((output,), self._mod_extra_config.outputs)
542544
if self.gather_output:
543-
output = self.collective_func(output)
545+
output = self.orig_mod.collective_func(output)
544546
return self.post_all_reduce(output)
545547

546548
def post_all_reduce(self, output):
@@ -689,7 +691,7 @@ def __init__(self, mod, mod_extra_config, *args, **kwargs):
689691
super().__init__()
690692
set_attrs_from_orig_model(self, mod, mod_extra_config)
691693
if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
692-
self.orig_fetch_from_cache = types.MethodType(mod.fetch_from_cache.__func__, self)
694+
self.orig_fetch_from_cache = mod.fetch_from_cache
693695
self.quant_input = self._mod_extra_config.inputs[0]
694696
self.dequant_output = self._mod_extra_config.outputs[0]
695697
if self.use_qdq:
@@ -775,7 +777,7 @@ def forward_quant(self, input):
775777

776778
def forward_measure(self, input):
777779
measure_input((input,), observer=self._mod_extra_config.inputs)
778-
output = self.forward_orig(input)
780+
output = self.orig_mod(input)
779781
measure_output((output,), self._mod_extra_config.outputs)
780782
return output
781783

@@ -816,7 +818,7 @@ def forward_quant(self, x, dim=None, invAttnHead=None):
816818

817819
def forward_measure(self, x, dim=None, invAttnHead=None):
818820
measure_input((x,), observer=self._mod_extra_config.inputs)
819-
output = self.forward_orig(x, dim)
821+
output = self.orig_mod(x, dim)
820822
measure_output((output,), self._mod_extra_config.outputs)
821823
return output
822824

@@ -861,7 +863,7 @@ def forward_quant(self, input, scale: float = 1.0):
861863

862864
def forward_measure(self, input, scale: float = 1.0):
863865
measure_input((input,), observer=self._mod_extra_config.inputs)
864-
output = self.forward_orig(input, scale)
866+
output = self.orig_mod(input, scale)
865867
measure_output((output,), self._mod_extra_config.outputs)
866868
return output
867869

@@ -916,7 +918,7 @@ def forward_quant(self, input, scale: float = 1.0):
916918

917919
def forward_measure(self, input, scale: float = 1.0):
918920
measure_input((input,), observer=self._mod_extra_config.inputs)
919-
output = self.forward_orig(input, scale)
921+
output = self.orig_mod(input, scale)
920922
measure_output((output,), self._mod_extra_config.outputs)
921923
return output
922924

test/3x/torch/algorithms/fp8_quant/test_register_apis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def forward_measure(self, input):
231231
)
232232

233233
measure_input((input,), observer=self._mod_extra_config.inputs)
234-
output = self.forward_orig(input)
234+
output = self.orig_mod(input)
235235
measure_output((output,), self._mod_extra_config.outputs)
236236
return output
237237

0 commit comments

Comments
 (0)