diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py
index 60d70049e..c55254d95 100644
--- a/src/brevitas/export/common/handler/qcdq.py
+++ b/src/brevitas/export/common/handler/qcdq.py
@@ -136,15 +136,20 @@ def prepare_for_export(self, module):
                 for tm in module.tracked_module_list}
             # Get the first quant weight as representative
             quant_weight = module.tracked_module_list[0].quant_weight()
+
+            # (B)float16 is not supported with standard Q/DQ ops, thus we store its original dtype
+            # and we cast it to float32. The original dtype is then restored during the forward pass
+            scale = quant_weight.scale
+            self.scale_dtype = scale.dtype
+            if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
+                scale = self.cast_fn(scale, torch.float32)
+
             self.symbolic_kwargs['int_weights'] = int_weights
             self.symbolic_kwargs['bit_width'] = quant_weight.bit_width
             self.symbolic_kwargs['clip_symbolic_kwargs'] = self.int_clip_symbolic_kwargs(
                 module.is_narrow_range, module.is_signed, quant_weight.bit_width)
             self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
-                quant_weight.scale,
-                quant_weight.zero_point,
-                quant_weight.bit_width,
-                module.is_signed)
+                scale, quant_weight.zero_point, quant_weight.bit_width, module.is_signed)
         else:
             self.symbolic_kwargs = None
 
@@ -160,16 +165,13 @@ def symbolic_execution(self, x: Tensor):
         scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape')
         # Workaround to trick the tracer into believing all return values are used
         self.assert_ge_zero(scale, zero_point, bit_width)
-        scale_dtype = scale.dtype
-        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
-            scale = self.cast_fn(scale, torch.float32)
-            dequantize_symbolic_kwargs['scale'] = scale
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
-        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
-            x = self.cast_fn(x, scale_dtype)
-            scale = self.cast_fn(scale, scale_dtype)
+        # After dequantization, cast both input and scale to the correct dtype
+        if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
+            x = self.cast_fn(x, self.scale_dtype)
+            scale = self.cast_fn(scale, self.scale_dtype)
         # Restore the original shapes to guarantee correct shape propagation downstream
         scale = scale.view(scale_orig_shape)
         zero_point = zero_point.view_as(scale)
@@ -214,12 +216,21 @@ def prepare_for_export(self, module):
         if module.is_quant_enabled:
             self.validate(module)
             self.symbolic_kwargs['bit_width'] = module.bit_width()
+
+            # (B)float16 is not supported with standard Q/DQ ops, thus we store its original dtype
+            # and we cast it to float32. The original dtype is then restored during the forward pass
+            scale = module.scale()
+            self.scale_dtype = scale.dtype
+            if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
+                scale = self.cast_fn(scale, torch.float32)
+
             self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs(
-                module.scale(), module.zero_point(), module.bit_width(), module.is_signed)
+                scale, module.zero_point(), module.bit_width(), module.is_signed)
             self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
-                module.scale(), module.zero_point(), module.bit_width(), module.is_signed)
+                scale, module.zero_point(), module.bit_width(), module.is_signed)
             self.symbolic_kwargs['clip_symbolic_kwargs'] = self.int_clip_symbolic_kwargs(
                 module.is_narrow_range, module.is_signed, module.bit_width())
+
         else:
             self.symbolic_kwargs = None
 
@@ -235,19 +246,17 @@ def symbolic_execution(self, x: Tensor):
         bit_width = self.symbolic_kwargs['bit_width']
         # Workaround to trick the tracer into believing all return values are used
         self.assert_ge_zero(scale, zero_point, bit_width)
-        scale_dtype = scale.dtype
-        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
+        # If original dtype of scale is (b)float16, cast the input
+        if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
             x = self.cast_fn(x, torch.float32)
-            scale = self.cast_fn(scale, torch.float32)
-            quantize_symbolic_kwargs['scale'] = scale
-            dequantize_symbolic_kwargs['scale'] = scale
         x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
-        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
-            x = self.cast_fn(x, scale_dtype)
-            scale = self.cast_fn(scale, scale_dtype)
+        # After dequantization, cast both output and scale to the correct dtype
+        if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
+            x = self.cast_fn(x, self.scale_dtype)
+            scale = self.cast_fn(scale, self.scale_dtype)
         # Restore the original shapes to guarantee correct shape propagation downstream
         scale = scale.view(scale_orig_shape)
         zero_point = zero_point.view_as(scale)
@@ -298,10 +307,13 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None):
         zero_point = to_0dim_if_scalar(zero_point).expand_as(scale)
         zero_point = self.zero_point_with_dtype(
             True, bit_width, zero_point)  # assume signed is True
+        # If original dtype of scale is (b)float16, store the original dtype
+        # and cast the scale to float32
         scale_dtype = scale.dtype
         if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
             scale = self.cast_fn(scale, torch.float32)
         y = self.dequantize_fn(int_bias, scale, zero_point, quant_axis)
+        # After dequantization, cast both output and scale to the correct dtype
         if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
             y = self.cast_fn(y, scale_dtype)
             scale = self.cast_fn(scale, scale_dtype)
@@ -331,6 +343,8 @@ def symbolic_execution(
         output_bit_width = self.symbolic_kwargs['output_bit_width']
         dtype = self.int8_dtype() if signed else self.uint8_dtype()
         trunc_scale = 2.0 ** (input_bit_width - output_bit_width)
+        # If original dtype of scale is (b)float16, store the original dtype
+        # and cast the scale to float32
         scale_dtype = scale.dtype
         if scale_dtype == torch.bfloat16 or scale_dtype == torch.float16:
             scale = self.cast_fn(scale, torch.float32)
@@ -345,6 +359,7 @@ def symbolic_execution(
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, flat_scale, zp, self.quant_axis(scale))
+        # After dequantization, cast both output and scale to the correct dtype
         if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
             x = self.cast_fn(x, scale_dtype)
             scale = self.cast_fn(scale, scale_dtype)
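For context, every hunk above applies the same pattern: standard Q/DQ ops do not accept (b)float16 scales, so the handlers remember the scale's original dtype, cast it to float32 around the quantize/clip/dequantize chain, and cast the result back afterwards. The sketch below is a minimal, standalone illustration of that pattern, not Brevitas code; the helper name `qdq_with_half_scale` and the plain int8 round/clamp math are assumptions made for the example.

```python
import torch


def qdq_with_half_scale(x: torch.Tensor,
                        scale: torch.Tensor,
                        zero_point: torch.Tensor,
                        narrow_range: bool = False) -> torch.Tensor:
    # Remember the original dtype of the scale (e.g. torch.bfloat16)
    orig_dtype = scale.dtype
    if orig_dtype in (torch.float16, torch.bfloat16):
        # Cast scale and input to float32 so the Q/DQ math runs in a supported dtype
        scale = scale.to(torch.float32)
        x = x.to(torch.float32)
    # Simplified int8 quantize -> clip -> dequantize chain (stand-in for the handlers'
    # quantize_fn / clip_fn / dequantize_fn symbolic calls)
    qmin = -127 if narrow_range else -128
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, 127)
    y = (q - zero_point) * scale
    if orig_dtype in (torch.float16, torch.bfloat16):
        # Restore the original dtype so downstream ops keep seeing (b)float16
        y = y.to(orig_dtype)
    return y


# Usage: a bfloat16 tensor round-trips through Q/DQ and comes back as bfloat16
x = torch.randn(4, dtype=torch.bfloat16)
print(qdq_with_half_scale(x, torch.tensor(0.1, dtype=torch.bfloat16), torch.tensor(0.0)))
```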