diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py
index 2aceb334f..b2395b774 100644
--- a/src/brevitas/quant_tensor/__init__.py
+++ b/src/brevitas/quant_tensor/__init__.py
@@ -317,6 +317,8 @@ def cat(tensors, dim, out=None):
 
     def __neg__(self):
         neg_value = (-self.int(float_datatype=True) - self.zero_point) * self.scale
+        # In case the dtype of self.int is different from the one of the scale
+        neg_value = neg_value.type(self.scale.dtype)
         if self.signed:
             return QuantTensor(
                 value=neg_value,
@@ -447,6 +449,8 @@ def __truediv__(self, other):
     def __abs__(self):
         if self.signed:
             abs_value = (torch.abs(self.int(float_datatype=True)) - self.zero_point) * self.scale
+            # In case the dtype of self.int is different from the one of the scale
+            abs_value = abs_value.type(self.scale.dtype)
             return QuantTensor(
                 value=abs_value,
                 scale=self.scale,
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
index ac084c57b..b4025319a 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -307,6 +307,7 @@ def main():
 
     # Get the model from torchvision
     model = get_torchvision_model(args.model_name)
+    model = model.to(dtype)
 
     # Preprocess the model for quantization
     if args.target_backend == 'flexml':
@@ -314,7 +315,7 @@ def main():
         img_shape = model_config['center_crop_shape']
         model = preprocess_for_flexml_quantize(
             model,
-            torch.ones(1, 3, img_shape, img_shape),
+            torch.ones(1, 3, img_shape, img_shape, dtype=dtype),
             equalize_iters=args.graph_eq_iterations,
             equalize_merge_bias=args.graph_eq_merge_bias,
             merge_bn=not args.calibrate_bn)
@@ -330,7 +331,6 @@ def main():
     if args.act_equalization is not None:
         print("Applying activation equalization:")
         apply_act_equalization(model, calib_loader, layerwise=args.act_equalization == 'layerwise')
-    model = model.to(dtype)
 
     # Define the quantized model
     quant_model = quantize_model(
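For context, here is a minimal sketch of the dtype mismatch the new casts guard against. It is an illustrative snippet, not Brevitas code: when the integer representation comes back as float32 while the scale lives in float16, PyTorch's type promotion makes the product float32, so without the explicit `.type(self.scale.dtype)` the resulting value would no longer share the scale's dtype. The same concern presumably motivates moving `model = model.to(dtype)` before preprocessing, so the model and the `dtype`-matched dummy input passed to `preprocess_for_flexml_quantize` are consistent.

```python
import torch

# Illustrative only (not Brevitas code): reproduce the type promotion
# that the added .type(self.scale.dtype) casts guard against.
int_repr = torch.tensor([1.0, -2.0, 3.0], dtype=torch.float32)  # stands in for self.int(float_datatype=True)
scale = torch.tensor(0.5, dtype=torch.float16)                  # stands in for self.scale
zero_point = torch.tensor(0.0, dtype=torch.float16)             # stands in for self.zero_point

neg_value = (-int_repr - zero_point) * scale
print(neg_value.dtype)  # torch.float32 -- promoted away from the scale dtype

neg_value = neg_value.type(scale.dtype)
print(neg_value.dtype)  # torch.float16 -- matches the scale dtype after the cast
```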