Commit dbaba42

Fix
Giuseppe5 committed Dec 11, 2023
1 parent 97d9808 commit dbaba42
Showing 1 changed file with 36 additions and 21 deletions.
src/brevitas/export/common/handler/qcdq.py: 36 additions & 21 deletions (57 changes)

@@ -136,15 +136,20 @@ def prepare_for_export(self, module):
                 for tm in module.tracked_module_list}
             # Get the first quant weight as representative
             quant_weight = module.tracked_module_list[0].quant_weight()
+
+            # (B)float16 is not supported with standard Q/DQ ops, thus we store its original dtype
+            # and we cast it to float32. The original dtype is then restored during the forward pass
+            scale = quant_weight.scale
+            self.scale_dtype = scale.dtype
+            if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
+                scale = self.cast_fn(scale, torch.float32)
+
             self.symbolic_kwargs['int_weights'] = int_weights
             self.symbolic_kwargs['bit_width'] = quant_weight.bit_width
             self.symbolic_kwargs['clip_symbolic_kwargs'] = self.int_clip_symbolic_kwargs(
                 module.is_narrow_range, module.is_signed, quant_weight.bit_width)
             self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
-                quant_weight.scale,
-                quant_weight.zero_point,
-                quant_weight.bit_width,
-                module.is_signed)
+                scale, quant_weight.zero_point, quant_weight.bit_width, module.is_signed)
         else:
             self.symbolic_kwargs = None

@@ -160,16 +165,13 @@ def symbolic_execution(self, x: Tensor):
         scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape')
         # Workaround to trick the tracer into believing all return values are used
         self.assert_ge_zero(scale, zero_point, bit_width)
-        scale_dtype = scale.dtype
-        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
-            scale = self.cast_fn(scale, torch.float32)
-            dequantize_symbolic_kwargs['scale'] = scale
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
-        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
-            x = self.cast_fn(x, scale_dtype)
-            scale = self.cast_fn(scale, scale_dtype)
+        # After dequantization, cast both input and scale to the correct dtype
+        if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
+            x = self.cast_fn(x, self.scale_dtype)
+            scale = self.cast_fn(scale, self.scale_dtype)
         # Restore the original shapes to guarantee correct shape propagation downstream
         scale = scale.view(scale_orig_shape)
         zero_point = zero_point.view_as(scale)
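
The weight-handler hunks above move the (b)float16-to-float32 cast of the scale from trace time into prepare_for_export and record the original dtype on the handler. The dtype round-trip the exported graph ends up performing can be sketched outside Brevitas roughly as follows; the function name, the fake Q/DQ pair, and the signed 8-bit clamp range are illustrative assumptions, not Brevitas APIs.

import torch

def qdq_with_half_precision_scale(x, scale, zero_point):
    # Per the comment in the diff, the standard Q/DQ export ops do not accept a
    # (b)float16 scale, so cast up front and restore the original dtype afterwards.
    orig_dtype = scale.dtype
    if orig_dtype in (torch.float16, torch.bfloat16):
        x = x.to(torch.float32)
        scale = scale.to(torch.float32)
    # Stand-ins for the exported QuantizeLinear / DequantizeLinear pair.
    q = torch.clamp(torch.round(x / scale) + zero_point, -128, 127)
    y = (q - zero_point) * scale
    if orig_dtype in (torch.float16, torch.bfloat16):
        y = y.to(orig_dtype)
        scale = scale.to(orig_dtype)
    return y, scale
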
@@ -214,12 +216,21 @@ def prepare_for_export(self, module):
         if module.is_quant_enabled:
             self.validate(module)
             self.symbolic_kwargs['bit_width'] = module.bit_width()
+
+            # (B)float16 is not supported with standard Q/DQ ops, thus we store its original dtype
+            # and we cast it to float32. The original dtype is then restored during the forward pass
+            scale = module.scale()
+            self.scale_dtype = scale.dtype
+            if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
+                scale = self.cast_fn(scale, torch.float32)
+
             self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs(
-                module.scale(), module.zero_point(), module.bit_width(), module.is_signed)
+                scale, module.zero_point(), module.bit_width(), module.is_signed)
             self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
-                module.scale(), module.zero_point(), module.bit_width(), module.is_signed)
+                scale, module.zero_point(), module.bit_width(), module.is_signed)
             self.symbolic_kwargs['clip_symbolic_kwargs'] = self.int_clip_symbolic_kwargs(
                 module.is_narrow_range, module.is_signed, module.bit_width())
+
         else:
             self.symbolic_kwargs = None

@@ -235,19 +246,17 @@ def symbolic_execution(self, x: Tensor):
         bit_width = self.symbolic_kwargs['bit_width']
         # Workaround to trick the tracer into believing all return values are used
         self.assert_ge_zero(scale, zero_point, bit_width)
-        scale_dtype = scale.dtype
-        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
+        # If original dtype of scale is (b)float16, cast the input
+        if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
             x = self.cast_fn(x, torch.float32)
-            scale = self.cast_fn(scale, torch.float32)
-            quantize_symbolic_kwargs['scale'] = scale
-            dequantize_symbolic_kwargs['scale'] = scale
         x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
-        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
-            x = self.cast_fn(x, scale_dtype)
-            scale = self.cast_fn(scale, scale_dtype)
+        # After dequantization, cast both output and scale to the correct dtype
+        if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
+            x = self.cast_fn(x, self.scale_dtype)
+            scale = self.cast_fn(scale, self.scale_dtype)
         # Restore the original shapes to guarantee correct shape propagation downstream
         scale = scale.view(scale_orig_shape)
         zero_point = zero_point.view_as(scale)
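
The activation-handler hunks follow the same pattern, and they also show why the dtype is stashed on the handler rather than read back from the scale: once prepare_for_export has cast the cached scale to float32, its dtype no longer tells the handler what to cast back to at trace time. A minimal sketch of that state-carrying idea, using a hypothetical class that is not part of Brevitas:

import torch

class CastingHandlerSketch:
    # Hypothetical illustration only: record the original scale dtype once at
    # prepare time, because the cached scale is already float32 by execution time.

    def prepare(self, scale):
        self.scale_dtype = scale.dtype
        if self.scale_dtype in (torch.float16, torch.bfloat16):
            scale = scale.to(torch.float32)
        self.scale = scale  # float32 copy used while tracing

    def execute(self, x):
        if self.scale_dtype in (torch.float16, torch.bfloat16):
            x = x.to(torch.float32)
        y = x * self.scale  # stand-in for the quantize/clip/dequantize chain
        if self.scale_dtype in (torch.float16, torch.bfloat16):
            y = y.to(self.scale_dtype)
        return y
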
@@ -298,10 +307,13 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None):
         zero_point = to_0dim_if_scalar(zero_point).expand_as(scale)
         zero_point = self.zero_point_with_dtype(
             True, bit_width, zero_point)  # assume signed is True
+        # If original dtype of scale is (b)float16, store the original dtype
+        # and cast the scale to float32
         scale_dtype = scale.dtype
         if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
             scale = self.cast_fn(scale, torch.float32)
         y = self.dequantize_fn(int_bias, scale, zero_point, quant_axis)
+        # After dequantization, cast both output and scale to the correct dtype
         if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
             y = self.cast_fn(y, scale_dtype)
             scale = self.cast_fn(scale, scale_dtype)
@@ -331,6 +343,8 @@ def symbolic_execution(
         output_bit_width = self.symbolic_kwargs['output_bit_width']
         dtype = self.int8_dtype() if signed else self.uint8_dtype()
         trunc_scale = 2.0 ** (input_bit_width - output_bit_width)
+        # If original dtype of scale is (b)float16, store the original dtype
+        # and cast the scale to float32
         scale_dtype = scale.dtype
         if scale_dtype == torch.bfloat16 or scale_dtype == torch.float16:
             scale = self.cast_fn(scale, torch.float32)
@@ -345,6 +359,7 @@ def symbolic_execution(
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, flat_scale, zp, self.quant_axis(scale))
+        # After dequantization, cast both output and scale to the correct dtype
         if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
             x = self.cast_fn(x, scale_dtype)
             scale = self.cast_fn(scale, scale_dtype)
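
In the truncation hunks, trunc_scale = 2.0 ** (input_bit_width - output_bit_width) rescales the accumulator down to the output bit width before the same cast-and-restore treatment is applied to the scale. A quick worked example with illustrative numbers (floor rounding chosen purely for demonstration):

import torch

input_bit_width, output_bit_width = 16, 8
trunc_scale = 2.0 ** (input_bit_width - output_bit_width)  # 256.0, i.e. drop the 8 LSBs

acc = torch.tensor([12345., -6789.])   # integer accumulator values held as float
print(torch.floor(acc / trunc_scale))  # tensor([ 48., -27.])
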
