
Commit 0f6e6e0

[SW-204341] explicit scale format for ops (#73)
* [SW-204341] explicit scale format for ops
  Added a wrapper around the fp8 functions. The wrapper decides which flavor of the function to call according to the scale format, and the helper modules call the wrapper. The cast flavor is likewise chosen according to the scale format.
* [SW-204341] Adjust softmax API, remove commented-out code
* [SW-204341] Fixes from CR 1
* [SW-204341] Fixes from CR 2
* [SW-204341] Add missing arg in fsdpa
  Signed-off-by: Uri Livne <[email protected]>
* [SW-204341] Enhance SDPA for measure and quant
* [SW-204341] Remove sdpa quantized ops
* Reland per-op class with more enhancements
* [SW-204341] Reland specific arguments, rename class to wrapper
* Added call with self in patched lm head; rebased on top of master before the next force push
* Fix mistake in conflict resolution; restore MethodType fix
* Another fix
* Modified fp8 matmul test to test the quantized matmul func
* Another fix of a rebase mistake
* Hopefully the last rebase-mistake fix
* Restore backward-compatibility import protection

---------

Signed-off-by: Uri Livne <[email protected]>
1 parent ce86dc1 commit 0f6e6e0

4 files changed: 230 additions, 154 deletions


neural_compressor/torch/algorithms/fp8_quant/_core/quant_dequant.py

Lines changed: 18 additions & 4 deletions
@@ -58,6 +58,14 @@ def qdq_init(self):
         self.quant_max = int(torch.finfo(self.lp_dtype).max)
         self.forward = self.forward_qdq

+    def set_cast_to_op(self):
+        return torch.ops.hpu.cast_to_fp8_v2.scalar if self.scale_format == ScaleFormat.SCALAR else \
+            torch.ops.hpu.cast_to_fp8_v2
+
+    def set_cast_from_op(self):
+        return torch.ops.hpu.cast_from_fp8.scalar if self.scale_format == ScaleFormat.SCALAR else \
+            torch.ops.hpu.cast_from_fp8
+
     @abstractmethod
     def forward(self, *args, **kwargs):
         pass
@@ -95,8 +103,10 @@ def __init__(self, scale_inv, lp_dtype, hp_dtype, *args, **kwargs):
             quantize_per_channel_to_fp8 if self.scale.numel() > 1 else quantize_per_tensor_to_fp8
         )

+        self.cast_to_op = self.set_cast_to_op()
+
     def forward(self, x):
-        return cast_to_fp8_fcn(x, self.lp_dtype, self.scale_inv)
+        return self.cast_to_op(x, self.scale_inv, False, False, self.lp_dtype)[0]

     def forward_qdq(self, x):
         return self.quantize_op(
@@ -124,8 +134,10 @@ def __init__(self, scale, lp_dtype, hp_dtype, *args, **kwargs):
             dequantize_per_channel_from_fp8 if self.scale.numel() > 1 else dequantize_per_tensor_from_fp8
         )

+        self.cast_from_op = self.set_cast_from_op()
+
     def forward(self, x):
-        return cast_from_fp8_fcn(x, self.hp_dtype, self.scale)
+        return self.cast_from_op(x, self.scale, self.hp_dtype)

     def forward_qdq(self, x):
         return self.dequantize_op(
@@ -150,14 +162,16 @@ def __init__(self, scale_inv, lp_dtype, hp_dtype, *args, **kwargs):
         super(QuantDequant, self).__init__(lp_dtype, hp_dtype, *args, **kwargs)
         self.scale_inv = create_scale_tensor(scale_inv, self.scale_format)
         self.scale = create_scale_tensor(1 / scale_inv, self.scale_format)
+        self.cast_to_op = self.set_cast_to_op()
+        self.cast_from_op = self.set_cast_from_op()

     def forward(self, x, *args, **kwargs):
-        y = cast_to_fp8_fcn(x, self.lp_dtype, self.scale_inv)
+        y = self.cast_to_op(x, self.scale_inv, False, False, self.lp_dtype)[0]
         # mark_step is needed so fuser won't remove 2 consecutive casts.
         # will be removed once SW-196431 is implemented
         # Call cur_accelerator.synchronize() which will call mark_step() as well
         cur_accelerator.synchronize()
-        z = cast_from_fp8_fcn(y, self.hp_dtype, self.scale)
+        z = self.cast_from_op(y, self.scale, self.hp_dtype)
         cur_accelerator.synchronize()
         return z
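For reference, the change above selects the HPU cast op variant once, at init time, and then calls it in forward. Below is a minimal sketch of that selection pattern, assuming a Gaudi/HPU environment; select_cast_ops is a hypothetical standalone helper (the real code uses the set_cast_to_op / set_cast_from_op methods shown in the diff), the absolute import path is inferred from the relative imports in this package, and the boolean flags simply mirror the calls above.

import torch
# ScaleFormat path inferred from "from .._quant_common.quant_config import ScaleFormat"
from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import ScaleFormat

def select_cast_ops(scale_format):
    # Hypothetical helper mirroring set_cast_to_op / set_cast_from_op:
    # the .scalar overloads are used when scales are kept as plain Python scalars.
    if scale_format == ScaleFormat.SCALAR:
        return torch.ops.hpu.cast_to_fp8_v2.scalar, torch.ops.hpu.cast_from_fp8.scalar
    return torch.ops.hpu.cast_to_fp8_v2, torch.ops.hpu.cast_from_fp8

# Usage, following the calls in the diff (lp_dtype is the fp8 dtype, hp_dtype e.g. torch.bfloat16):
# cast_to, cast_from = select_cast_ops(scale_format)
# y = cast_to(x, scale_inv, False, False, lp_dtype)[0]   # same positional flags as in the diff
# z = cast_from(y, scale, hp_dtype)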

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+
+from .._quant_common.quant_config import ScaleFormat
+from ..utils.logger import logger
+
+try:  # backwards compatibility for 1.16
+    from habana_frameworks.torch.hpex.kernels import fp8_fused_sdpa
+except ImportError:
+    pass
+
+import torch
+
+from abc import ABC, abstractmethod
+from enum import Enum, auto
+
+class OP_TYPE(Enum):
+    # class per hpu custom fp8 ops used in patched modules logic
+    GEMM = auto(),
+    SOFTMAX = auto()
+    CONV = auto()
+    FSDPA = auto()
+
+
+class QuantizedHpuFuncWrapper(ABC):
+    """
+    Base class for wrapping calls to hpu custom fp8 ops.
+    The concrete class object is created in patched module init in call to get_hpu_quantized_func_wrapper.
+    Concrete class should define get_default_quantized_func method.
+    Concrete class may override base class methods in case custom op logic is unique, see examples in concrete
+    classes below.
+    """
+    def __init__(self, scale_format):
+        self.set_quantized_func(scale_format)
+        self.quantized_func_args = None
+
+    @abstractmethod
+    def get_default_quantized_func(self):
+        raise NotImplementedError()
+
+    def get_scalar_quantized_func(self):
+        return self.get_default_quantized_func().scalar
+
+    def set_quantized_func(self, scale_format):
+        if scale_format == ScaleFormat.SCALAR:
+            self._quantized_func_ = self.get_scalar_quantized_func()
+        elif scale_format == ScaleFormat.CONST:
+            self._quantized_func_ = self.get_default_quantized_func()
+        else:
+            raise ValueError("Unexpected scale format - {}".format(scale_format))
+
+    def __call__(self, *args, **kwargs):
+        return self._quantized_func_(*args, **kwargs)
+
+class QuantizedHpuMatmul(QuantizedHpuFuncWrapper):
+
+    def get_default_quantized_func(self):
+        return torch.ops.hpu.fp8_gemm_v2
+
+    # only specific arguments are defined, to avoid having all other arguments defined in each call in patched modules.
+    def __call__(self, input, other, out=None, out_dtype=torch.bfloat16, scale_input_inv=None, scale_other_inv=None):
+        return self._quantized_func_(input,
+                                     False,
+                                     other,
+                                     False,
+                                     out,
+                                     out_dtype,
+                                     scale_input_inv,
+                                     scale_other_inv,
+                                     None,
+                                     False)
+
+class QuantizedHpuConv(QuantizedHpuFuncWrapper):
+
+    def get_default_quantized_func(self):
+        return torch.ops.hpu.conv2d_fp8
+
+    @staticmethod
+    def to_list_if_necessary(param):
+        return param if hasattr(param, "__iter__") else [param] * 2
+
+    # only specific arguments are defined, to avoid having all other arguments defined in each call in patched modules.
+    def __call__(self,
+                 input,
+                 weight,
+                 bias,
+                 stride,
+                 padding,
+                 dilation,
+                 groups,
+                 out_dtype=torch.bfloat16,
+                 scale_input_inv=None,
+                 scale_other_inv=None):
+
+        return self._quantized_func_(input=input,
+                                     weight=weight,
+                                     bias=bias,
+                                     stride=self.to_list_if_necessary(stride),
+                                     padding=self.to_list_if_necessary(padding),
+                                     dilation=self.to_list_if_necessary(dilation),
+                                     groups=groups,
+                                     out_dtype=out_dtype,
+                                     scale_input=scale_input_inv,
+                                     scale_weight=scale_other_inv)
+
+class QuantizedHpuSoftmax(QuantizedHpuFuncWrapper):
+
+    def get_default_quantized_func(self):
+        return torch.ops.hpu.softmax_fp8
+
+    def get_scalar_quantized_func(self):
+        # softmax custom op has different scalar impl name
+        return self.get_default_quantized_func().Scalar_scales
+
+class QuantizedHpuFSDPA(QuantizedHpuFuncWrapper):
+
+    def __init__(self, scale_format):
+        # FSDPA isn't optimized for scalar flavor due to complexity of specific torch op api selection
+        self._quantized_func_ = self.get_default_quantized_func()
+
+    def get_default_quantized_func(self):
+        return fp8_fused_sdpa
+
+    def get_scalar_quantized_func(self):
+        raise NotImplementedError()
+
+_OP_TYPE_HPU_QUANTIZED_WRAPPER_CLASSES = {OP_TYPE.GEMM : QuantizedHpuMatmul,
+                                          OP_TYPE.SOFTMAX : QuantizedHpuSoftmax,
+                                          OP_TYPE.CONV : QuantizedHpuConv,
+                                          OP_TYPE.FSDPA : QuantizedHpuFSDPA
+                                          }
+
+def get_hpu_quantized_func_wrapper(op_type, scale_format):
+    quantized_hpu_wrapper_class = _OP_TYPE_HPU_QUANTIZED_WRAPPER_CLASSES[op_type]
+    return quantized_hpu_wrapper_class(scale_format)
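Per the base-class docstring, a patched module is expected to build the wrapper once in its __init__ via get_hpu_quantized_func_wrapper and then invoke it like a function. Below is a minimal, hypothetical usage sketch; PatchedLinearLike, its scale attributes, and the import location of the new module (its file path is not shown in this excerpt) are illustrative assumptions, not part of the patch.

import torch
# ScaleFormat path inferred from the relative import in the new module above.
from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import ScaleFormat
# OP_TYPE and get_hpu_quantized_func_wrapper come from the new module above;
# its path is not shown here, so the import is left symbolic:
# from <new wrapper module> import OP_TYPE, get_hpu_quantized_func_wrapper

class PatchedLinearLike(torch.nn.Module):
    # Illustrative only: shows how a patched module would hold and call the wrapper.
    def __init__(self, scale_input_inv, scale_weight_inv, scale_format=ScaleFormat.CONST):
        super().__init__()
        # The SCALAR vs CONST decision happens once here, not on every forward call.
        self.matmul_fp8 = get_hpu_quantized_func_wrapper(OP_TYPE.GEMM, scale_format)
        self.scale_input_inv = scale_input_inv
        self.scale_weight_inv = scale_weight_inv

    def forward(self, x_fp8, weight_fp8):
        # The wrapper exposes only the arguments patched modules vary; the transpose flags and
        # the remaining fp8_gemm_v2 arguments are fixed inside QuantizedHpuMatmul.__call__.
        return self.matmul_fp8(x_fp8, weight_fp8,
                               out_dtype=torch.bfloat16,
                               scale_input_inv=self.scale_input_inv,
                               scale_other_inv=self.scale_weight_inv)

Scalar flavors (e.g. softmax_fp8.Scalar_scales) are selected the same way by passing ScaleFormat.SCALAR, except for FSDPA, which always falls back to the default fp8_fused_sdpa as noted in its __init__.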
