diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py
index 60d70049e..c55254d95 100644
--- a/src/brevitas/export/common/handler/qcdq.py
+++ b/src/brevitas/export/common/handler/qcdq.py
@@ -136,15 +136,20 @@ def prepare_for_export(self, module):
                 for tm in module.tracked_module_list}
             # Get the first quant weight as representative
             quant_weight = module.tracked_module_list[0].quant_weight()
+
+            # (B)float16 is not supported with standard Q/DQ ops, thus we store its original dtype
+            # and we cast it to float32. The original dtype is then restored during the forward pass
+            scale = quant_weight.scale
+            self.scale_dtype = scale.dtype
+            if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
+                scale = self.cast_fn(scale, torch.float32)
+
             self.symbolic_kwargs['int_weights'] = int_weights
             self.symbolic_kwargs['bit_width'] = quant_weight.bit_width
             self.symbolic_kwargs['clip_symbolic_kwargs'] = self.int_clip_symbolic_kwargs(
                 module.is_narrow_range, module.is_signed, quant_weight.bit_width)
             self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
-                quant_weight.scale,
-                quant_weight.zero_point,
-                quant_weight.bit_width,
-                module.is_signed)
+                scale, quant_weight.zero_point, quant_weight.bit_width, module.is_signed)
         else:
             self.symbolic_kwargs = None
 
@@ -160,16 +165,13 @@ def symbolic_execution(self, x: Tensor):
         scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape')
         # Workaround to trick the tracer into believing all return values are used
         self.assert_ge_zero(scale, zero_point, bit_width)
-        scale_dtype = scale.dtype
-        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
-            scale = self.cast_fn(scale, torch.float32)
-            dequantize_symbolic_kwargs['scale'] = scale
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
-        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
-            x = self.cast_fn(x, scale_dtype)
-            scale = self.cast_fn(scale, scale_dtype)
+        # After dequantization, cast both input and scale to the correct dtype
+        if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
+            x = self.cast_fn(x, self.scale_dtype)
+            scale = self.cast_fn(scale, self.scale_dtype)
         # Restore the original shapes to guarantee correct shape propagation downstream
         scale = scale.view(scale_orig_shape)
         zero_point = zero_point.view_as(scale)
@@ -214,12 +216,21 @@ def prepare_for_export(self, module):
         if module.is_quant_enabled:
             self.validate(module)
             self.symbolic_kwargs['bit_width'] = module.bit_width()
+
+            # (B)float16 is not supported with standard Q/DQ ops, thus we store its original dtype
+            # and we cast it to float32. The original dtype is then restored during the forward pass
+            scale = module.scale()
+            self.scale_dtype = scale.dtype
+            if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
+                scale = self.cast_fn(scale, torch.float32)
+
             self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs(
-                module.scale(), module.zero_point(), module.bit_width(), module.is_signed)
+                scale, module.zero_point(), module.bit_width(), module.is_signed)
             self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
-                module.scale(), module.zero_point(), module.bit_width(), module.is_signed)
+                scale, module.zero_point(), module.bit_width(), module.is_signed)
             self.symbolic_kwargs['clip_symbolic_kwargs'] = self.int_clip_symbolic_kwargs(
                 module.is_narrow_range, module.is_signed, module.bit_width())
+
         else:
             self.symbolic_kwargs = None
 
@@ -235,19 +246,17 @@ def symbolic_execution(self, x: Tensor):
         bit_width = self.symbolic_kwargs['bit_width']
         # Workaround to trick the tracer into believing all return values are used
         self.assert_ge_zero(scale, zero_point, bit_width)
-        scale_dtype = scale.dtype
-        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
+        # If original dtype of scale is (b)float16, cast the input
+        if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
             x = self.cast_fn(x, torch.float32)
-            scale = self.cast_fn(scale, torch.float32)
-            quantize_symbolic_kwargs['scale'] = scale
-            dequantize_symbolic_kwargs['scale'] = scale
         x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
-        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
-            x = self.cast_fn(x, scale_dtype)
-            scale = self.cast_fn(scale, scale_dtype)
+        # After dequantization, cast both output and scale to the correct dtype
+        if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
+            x = self.cast_fn(x, self.scale_dtype)
+            scale = self.cast_fn(scale, self.scale_dtype)
         # Restore the original shapes to guarantee correct shape propagation downstream
         scale = scale.view(scale_orig_shape)
         zero_point = zero_point.view_as(scale)
@@ -298,10 +307,13 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None):
         zero_point = to_0dim_if_scalar(zero_point).expand_as(scale)
         zero_point = self.zero_point_with_dtype(
             True, bit_width, zero_point)  # assume signed is True
+        # If original dtype of scale is (b)float16, store the original dtype
+        # and cast the scale to float32
         scale_dtype = scale.dtype
         if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
             scale = self.cast_fn(scale, torch.float32)
         y = self.dequantize_fn(int_bias, scale, zero_point, quant_axis)
+        # After dequantization, cast both output and scale to the correct dtype
         if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
             y = self.cast_fn(y, scale_dtype)
             scale = self.cast_fn(scale, scale_dtype)
@@ -331,6 +343,8 @@ def symbolic_execution(
         output_bit_width = self.symbolic_kwargs['output_bit_width']
         dtype = self.int8_dtype() if signed else self.uint8_dtype()
         trunc_scale = 2.0 ** (input_bit_width - output_bit_width)
+        # If original dtype of scale is (b)float16, store the original dtype
+        # and cast the scale to float32
         scale_dtype = scale.dtype
         if scale_dtype == torch.bfloat16 or scale_dtype == torch.float16:
             scale = self.cast_fn(scale, torch.float32)
@@ -345,6 +359,7 @@ def symbolic_execution(
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, flat_scale, zp, self.quant_axis(scale))
+        # After dequantization, cast both output and scale to the correct dtype
         if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
             x = self.cast_fn(x, scale_dtype)
             scale = self.cast_fn(scale, scale_dtype)
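For context, every hunk above applies the same pattern: standard Q/DQ ops do not accept (b)float16 scales, so the handlers remember the scale's original dtype, cast it to float32 around the quantize/clip/dequantize chain, and cast the result back afterwards. The sketch below is a minimal, standalone illustration of that pattern, not Brevitas code; the helper name `qdq_with_half_scale` and the plain int8 round/clamp math are assumptions made for the example.

```python
import torch


def qdq_with_half_scale(x: torch.Tensor,
                        scale: torch.Tensor,
                        zero_point: torch.Tensor,
                        narrow_range: bool = False) -> torch.Tensor:
    # Remember the original dtype of the scale (e.g. torch.bfloat16)
    orig_dtype = scale.dtype
    if orig_dtype in (torch.float16, torch.bfloat16):
        # Cast scale and input to float32 so the Q/DQ math runs in a supported dtype
        scale = scale.to(torch.float32)
        x = x.to(torch.float32)
    # Simplified int8 quantize -> clip -> dequantize chain (stand-in for the handlers'
    # quantize_fn / clip_fn / dequantize_fn symbolic calls)
    qmin = -127 if narrow_range else -128
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, 127)
    y = (q - zero_point) * scale
    if orig_dtype in (torch.float16, torch.bfloat16):
        # Restore the original dtype so downstream ops keep seeing (b)float16
        y = y.to(orig_dtype)
    return y


# Usage: a bfloat16 tensor round-trips through Q/DQ and comes back as bfloat16
x = torch.randn(4, dtype=torch.bfloat16)
print(qdq_with_half_scale(x, torch.tensor(0.1, dtype=torch.bfloat16), torch.tensor(0.0)))
```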