From e3f98d481082be6bc7548e2f3a7c5ac7d99cf73b Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 21 Dec 2023 23:05:37 +0000 Subject: [PATCH] Feat (ptq/evaluate): support for bfloat16 --- .../imagenet_classification/ptq/ptq_evaluate.py | 13 +++++++++---- .../imagenet_classification/utils.py | 4 ++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 454bf2488..c821dad33 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -72,6 +72,8 @@ metavar='ARCH', choices=model_names, help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)') +parser.add_argument( + '--dtype', default='float', choices=['float', 'bfloat16'], help='Data type to use') parser.add_argument( '--target-backend', default='fx', @@ -215,6 +217,7 @@ default=None, type=int, help='Accumulator Bit Width for GPFA2Q (default: None)') +parser.add_argument('--onnx-opset-version', default=None, type=int, help='ONNX opset version') add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)') add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)') add_bool_arg(parser, 'gpfa2q', default=False, help='GPFA2Q (default: disabled)') @@ -226,11 +229,11 @@ def main(): args = parser.parse_args() + dtype = getattr(torch, args.dtype) random.seed(SEED) np.random.seed(SEED) torch.manual_seed(SEED) - if args.act_quant_calibration_type == 'stats': act_quant_calib_config = str(args.act_quant_percentile) + 'stats' else: @@ -312,6 +315,7 @@ def main(): # Get the model from torchvision model = get_torchvision_model(args.model_name) + model = model.to(dtype) # Preprocess the model for quantization if args.target_backend == 'flexml': @@ -319,7 +323,7 @@ def main(): img_shape = model_config['center_crop_shape'] model = preprocess_for_flexml_quantize( model, - torch.ones(1, 3, img_shape, img_shape), + torch.ones(1, 3, img_shape, img_shape, dtype=dtype), equalize_iters=args.graph_eq_iterations, equalize_merge_bias=args.graph_eq_merge_bias, merge_bn=not args.calibrate_bn) @@ -339,6 +343,7 @@ def main(): # Define the quantized model quant_model = quantize_model( model, + dtype=dtype, backend=args.target_backend, scale_factor_type=args.scale_factor_type, bias_bit_width=args.bias_bit_width, @@ -405,7 +410,7 @@ def main(): # Validate the quant_model on the validation dataloader print("Starting validation:") - validate(val_loader, quant_model) + validate(val_loader, quant_model, stable=dtype != torch.bfloat16) if args.export_onnx_qcdq or args.export_torch_qcdq: # Generate reference input tensor to drive the export process @@ -418,7 +423,7 @@ def main(): export_name = os.path.join(args.export_dir, config) if args.export_onnx_qcdq: export_name = export_name + '.onnx' - export_onnx_qcdq(model, ref_input, export_name) + export_onnx_qcdq(model, ref_input, export_name, opset_version=args.onnx_opset_version) if args.export_torch_qcdq: export_name = export_name + '.pt' export_torch_qcdq(model, ref_input, export_name) diff --git a/src/brevitas_examples/imagenet_classification/utils.py b/src/brevitas_examples/imagenet_classification/utils.py index 033058219..f614f287c 100644 --- a/src/brevitas_examples/imagenet_classification/utils.py +++ b/src/brevitas_examples/imagenet_classification/utils.py @@ -61,7 +61,7 @@ def accuracy(output, target, topk=(1,), stable=False): return res -def validate(val_loader, model): +def validate(val_loader, model, stable=True): """ Run validation on the desired dataset """ @@ -82,7 +82,7 @@ def print_accuracy(top1, prefix=''): output = model(images) # measure accuracy - acc1, = accuracy(output, target, stable=True) + acc1, = accuracy(output, target, stable=stable) top1.update(acc1[0], images.size(0)) print_accuracy(top1, 'Total:')