diff --git a/src/brevitas_examples/imagenet_classification/a2q/a2q_train_models.py b/src/brevitas_examples/imagenet_classification/a2q/a2q_train_models.py index 6d2cba60e..97a214c9a 100644 --- a/src/brevitas_examples/imagenet_classification/a2q/a2q_train_models.py +++ b/src/brevitas_examples/imagenet_classification/a2q/a2q_train_models.py @@ -91,7 +91,12 @@ default=False, help="If true, save torch model to specified save path.") parser.add_argument( - "--apply-bias-corr", + "--apply-act-calibration", + action="store_true", + default=False, + help="If true, apply activation calibration to the quantized model.") +parser.add_argument( + "--apply-bias-correction", action="store_true", default=False, help="If true, apply bias correction to the quantized model.") @@ -137,7 +142,7 @@ model = utils.get_model_by_name( args.model_name, init_from_float_checkpoint=args.from_float_checkpoint) criterion = nn.CrossEntropyLoss() - optimizer = optim.Adam( + optimizer = optim.SGD( utils.filter_params(model.named_parameters(), args.weight_decay), lr=args.lr_init, weight_decay=args.weight_decay) @@ -146,9 +151,14 @@ # Calibrate the quant model on the calibration dataset if args.apply_ep_init: print("Applying EP-init:") - utils.apply_ep_init(model, random_inp) + model = utils.apply_ep_init(model, random_inp) + + # Calibrate the quant model on the calibration dataset + if args.apply_act_calibration: + print("Applying activation calibration:") + utils.apply_act_calibrate(calibloader, model) - if args.apply_bias_corr: + if args.apply_bias_correction: print("Applying bias correction:") utils.apply_bias_correction(calibloader, model) @@ -173,7 +183,7 @@ model.load_state_dict(best_weights) top_1, top_5, loss = utils.evaluate_topk_accuracies(testloader, model, criterion) - print(f"Final top_1={top_1:.1%}, top_5={top_5:.1%}, loss={loss:.3f}") + print(f"Final: top_1={top_1:.1%}, top_5={top_5:.1%}, loss={loss:.3f}") # save checkpoint os.makedirs(args.save_path, exist_ok=True) diff --git a/src/brevitas_examples/imagenet_classification/a2q/ep_init.py b/src/brevitas_examples/imagenet_classification/a2q/ep_init.py index 564d0fae2..026f78af8 100644 --- a/src/brevitas_examples/imagenet_classification/a2q/ep_init.py +++ b/src/brevitas_examples/imagenet_classification/a2q/ep_init.py @@ -127,3 +127,5 @@ def register_upper_bound(module: AccumulatorAwareParameterPreScaling, inp, outpu for hook in hook_list: hook.remove() + + return model diff --git a/src/brevitas_examples/imagenet_classification/a2q/utils.py b/src/brevitas_examples/imagenet_classification/a2q/utils.py index c7352fdb2..eaa4c200d 100644 --- a/src/brevitas_examples/imagenet_classification/a2q/utils.py +++ b/src/brevitas_examples/imagenet_classification/a2q/utils.py @@ -20,6 +20,7 @@ from brevitas.core.scaling.pre_scaling import AccumulatorAwareParameterPreScaling from brevitas.function import abs_binary_sign_grad from brevitas.graph.calibrate import bias_correction_mode +from brevitas.graph.calibrate import calibration_mode from .ep_init import apply_ep_init from .quant import * @@ -28,11 +29,12 @@ __all__ = [ "apply_ep_init", + "apply_act_calibrate", + "apply_bias_correction", "get_model_by_name", "filter_params", "create_calibration_dataloader", "get_cifar10_dataloaders", - "apply_bias_correction", "train_for_epoch", "evaluate_topk_accuracies"] @@ -239,6 +241,18 @@ def get_cifar10_dataloaders( return trainloader, testloader +def apply_act_calibrate(calib_loader, model): + model.eval() + dtype = next(model.parameters()).dtype + device = next(model.parameters()).device + with torch.no_grad(): + with calibration_mode(model): + for images, _ in tqdm(calib_loader): + images = images.to(device) + images = images.to(dtype) + model(images) + + def apply_bias_correction(calib_loader, model: nn.Module): model.eval() dtype = next(model.parameters()).dtype