Adding descriptions to argparser
i-colbert committed Jan 31, 2024
1 parent 5c34f9c commit 270011c
Showing 3 changed files with 201 additions and 54 deletions.
117 changes: 89 additions & 28 deletions src/brevitas_examples/imagenet_classification/a2q/a2q_train_models.py
@@ -15,30 +15,93 @@

import brevitas.config as config
from brevitas.export import export_qonnx
from brevitas_examples.imagenet_classification.a2q.ep_init import apply_bias_correction
from brevitas_examples.imagenet_classification.a2q.ep_init import apply_ep_init
import brevitas_examples.imagenet_classification.a2q.utils as utils

parser = argparse.ArgumentParser()
parser.add_argument("--data-root", type=str, required=True)
parser.add_argument("--model-name", type=str, default="quant_resnet18_w4a4_a2q_32b")
parser.add_argument("--save-path", type=str, default='outputs/')
parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument("--pin-memory", action="store_true", default=False)
parser.add_argument("--batch-size-train", type=int, default=256)
parser.add_argument("--batch-size-test", type=int, default=512)
parser.add_argument("--batch-size-calibration", type=int, default=256)
parser.add_argument("--calibration-samples", type=int, default=1000)
parser.add_argument("--weight-decay", type=float, default=1e-5)
parser.add_argument("--lr-init", type=float, default=1e-3)
parser.add_argument("--lr-step-size", type=int, default=30)
parser.add_argument("--lr-gamma", type=float, default=0.1)
parser.add_argument("--total-epochs", type=int, default=90)
parser.add_argument("--pretrained", action="store_true", default=False)
parser.add_argument("--save-ckpt", action="store_true", default=False)
parser.add_argument("--apply-bias-corr", action="store_true", default=False)
parser.add_argument("--apply-ep-init", action="store_true", default=False)
parser.add_argument("--export-to-qonnx", action="store_true", default=False)
parser.add_argument(
"--data-root", type=str, required=True, help="Directory where the dataset is stored.")
parser.add_argument(
"--model-name",
type=str,
default="quant_resnet18_w4a4_a2q_32b",
help="Name of model to train. Default: 'quant_resnet18_w4a4_a2q_32b'",
choices=utils.model_impl.keys())
parser.add_argument(
"--save-path",
type=str,
default="outputs/",
help="Directory where to save checkpoints. Default: 'outputs/'")
parser.add_argument(
"--num-workers",
type=int,
default=0,
help="Number of workers for the dataloader to use. Default: 0")
parser.add_argument(
"--pin-memory",
action="store_true",
default=False,
help="If true, pin memory for the dataloader.")
parser.add_argument(
"--batch-size-train",
type=int,
default=256,
help="Batch size for the training dataloader. Default: 256")
parser.add_argument(
"--batch-size-test",
type=int,
default=512,
help="Batch size for the testing dataloader. Default: 512")
parser.add_argument(
"--batch-size-calibration",
type=int,
default=256,
help="Batch size for the calibration dataloader. Default: 256")
parser.add_argument(
"--calibration-samples",
type=int,
default=1000,
help="Number of samples to use for calibration. Default: 1000")
parser.add_argument(
"--weight-decay",
type=float,
default=1e-5,
help="Weight decay for the Adam optimizer. Default: 0.00001")
parser.add_argument(
"--lr-init", type=float, default=1e-3, help="Initial learning rate. Default: 0.001")
parser.add_argument(
"--lr-step-size",
type=int,
default=30,
help="Step size for the learning rate scheduler. Default: 30")
parser.add_argument(
"--lr-gamma",
type=float,
default=0.1,
help="Default gamma for the learning rate scheduler. Default: 0.1")
parser.add_argument(
"--total-epochs", type=int, default=90, help="Total epoch to train the model for. Default: 90")
parser.add_argument(
"--from-float-checkpoint",
action="store_true",
default=False,
help="If true, use a pre-trained floating-point checkpoint.")
parser.add_argument(
"--save-torch-model",
action="store_true",
default=False,
help="If true, save torch model to specified save path.")
parser.add_argument(
"--apply-bias-corr",
action="store_true",
default=False,
help="If true, apply bias correction to the quantized model.")
parser.add_argument(
"--apply-ep-init",
action="store_true",
default=False,
help="If true, apply EP-init to the quantized model.")
parser.add_argument(
"--export-to-qonnx", action="store_true", default=False, help="If true, export model to QONNX.")

# ignore missing keys when loading pre-trained checkpoint
config.IGNORE_MISSING_KEYS = True
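
An aside on the new --model-name definition above (not part of the commit): its default, 'quant_resnet18_w4a4_a2q_32b', no longer appears in utils.model_impl after the key renames later in this commit, and argparse enforces choices only on values passed at the command line, never on the default. The mismatch would therefore surface later, in get_model_by_name's assert. A minimal sketch of that argparse behavior:

```python
import argparse

# Hypothetical names; only the default/choices interaction is the point here.
p = argparse.ArgumentParser()
p.add_argument("--model-name", default="not_a_choice", choices=["a", "b"])

args = p.parse_args([])   # no error: argparse never checks `default` against `choices`
print(args.model_name)    # -> not_a_choice

# p.parse_args(["--model-name", "c"])  # would exit: invalid choice: 'c'
```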
@@ -71,10 +134,8 @@
num_workers=args.num_workers,
subset_size=args.calibration_samples)

print(
f"Initializating {args.model_name} from",
"checkpoint..." if args.pretrained else "scratch...")
model = utils.get_model_by_name(args.model_name, args.pretrained)
model = utils.get_model_by_name(
args.model_name, init_from_float_checkpoint=args.from_float_checkpoint)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
utils.filter_params(model.named_parameters(), args.weight_decay),
@@ -85,11 +146,11 @@
# Calibrate the quant model on the calibration dataset
if args.apply_ep_init:
print("Applying EP-init:")
apply_ep_init(model, random_inp)
utils.apply_ep_init(model, random_inp)

if args.apply_bias_corr:
print("Applying bias correction:")
apply_bias_correction(calibloader, model)
utils.apply_bias_correction(calibloader, model)

best_top_1, best_weights = 0., copy.deepcopy(model.state_dict())
for epoch in range(args.total_epochs):
@@ -116,7 +177,7 @@

# save checkpoint
os.makedirs(args.save_path, exist_ok=True)
if args.save_ckpt:
if args.save_torch_model:
ckpt_path = f"{args.save_path}/{args.model_name}.pth"
torch.save(best_weights, ckpt_path)
with open(ckpt_path, "rb") as _file:
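
The tail of this file is truncated above, but it imports export_qonnx and exposes an --export-to-qonnx flag. A minimal sketch of how that flag is typically acted on with Brevitas (an assumption; the actual export call is not shown in this diff):

```python
# Hypothetical sketch: export the best weights to QONNX when the flag is set.
# The 32x32 input shape matches the CIFAR-10 dummy input used in utils.py.
if args.export_to_qonnx:
    model.load_state_dict(best_weights)
    model.cpu()
    export_qonnx(
        model,
        args=torch.randn(1, 3, 32, 32),
        export_path=f"{args.save_path}/{args.model_name}.onnx")
```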
14 changes: 0 additions & 14 deletions src/brevitas_examples/imagenet_classification/a2q/ep_init.py
@@ -7,28 +7,14 @@
import torch
from torch import Tensor
import torch.nn as nn
from tqdm import tqdm

from brevitas.core.scaling import AccumulatorAwareParameterPreScaling
from brevitas.function.shape import over_output_channels
from brevitas.graph.calibrate import bias_correction_mode
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL

__all__ = ["apply_bias_correction", "apply_ep_init"]


def apply_bias_correction(calib_loader, model: nn.Module):
model.eval()
dtype = next(model.parameters()).dtype
device = next(model.parameters()).device
with torch.no_grad():
with bias_correction_mode(model):
for (images, _) in tqdm(calib_loader):
images = images.to(device)
images = images.to(dtype)
model(images)


def get_a2q_module(module: nn.Module):
for submod in module.modules():
if isinstance(submod, AccumulatorAwareParameterPreScaling):
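
After this commit, ep_init.py keeps only the EP-init machinery; get_a2q_module (shown above, truncated) scans a layer's submodules for its AccumulatorAwareParameterPreScaling instance. A hedged usage sketch, assuming it returns None when a layer has no A2Q pre-scaling:

```python
# Sketch (not from the commit): map each quantized conv/linear layer of a
# model to its A2Q pre-scaling module so EP-init can rescale it.
import torch.nn as nn

from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL


def collect_a2q_modules(model: nn.Module) -> dict:
    found = {}
    for name, module in model.named_modules():
        if isinstance(module, QuantWBIOL):
            found[name] = get_a2q_module(module)  # assumed None if not A2Q
    return found
```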
124 changes: 112 additions & 12 deletions src/brevitas_examples/imagenet_classification/a2q/utils.py
@@ -19,28 +19,54 @@

from brevitas.core.scaling.pre_scaling import AccumulatorAwareParameterPreScaling
from brevitas.function import abs_binary_sign_grad
from brevitas.graph.calibrate import bias_correction_mode

from .ep_init import apply_ep_init
from .quant import *
from .resnet import float_resnet18
from .resnet import quant_resnet18

__all__ = [
"apply_ep_init",
"get_model_by_name",
"filter_params",
"create_calibration_dataloader",
"get_cifar10_dataloaders",
"apply_bias_correction",
"train_for_epoch",
"evaluate_topk_accuracies"]

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_impl = {
"float_resnet18":
float_resnet18,
"quant_resnet18_w4a4_a2q_32b":
"quant_resnet18_w4a4_a2q_16b":
partial(
quant_resnet18,
act_bit_width=4,
acc_bit_width=32,
acc_bit_width=16,
weight_bit_width=4,
weight_quant=CommonIntAccumulatorAwareWeightQuant),
"quant_resnet18_w4a4_a2q_16b":
"quant_resnet18_w4a4_a2q_15b":
partial(
quant_resnet18,
act_bit_width=4,
acc_bit_width=16,
acc_bit_width=15,
weight_bit_width=4,
weight_quant=CommonIntAccumulatorAwareWeightQuant),
"quant_resnet18_w4a4_a2q_14b":
partial(
quant_resnet18,
act_bit_width=4,
acc_bit_width=14,
weight_bit_width=4,
weight_quant=CommonIntAccumulatorAwareWeightQuant),
"quant_resnet18_w4a4_a2q_13b":
partial(
quant_resnet18,
act_bit_width=4,
acc_bit_width=13,
weight_bit_width=4,
weight_quant=CommonIntAccumulatorAwareWeightQuant),
"quant_resnet18_w4a4_a2q_12b":
@@ -50,18 +76,32 @@
acc_bit_width=12,
weight_bit_width=4,
weight_quant=CommonIntAccumulatorAwareWeightQuant),
"quant_resnet18_w4a4_a2q_plus_32b":
"quant_resnet18_w4a4_a2q_plus_16b":
partial(
quant_resnet18,
act_bit_width=4,
acc_bit_width=32,
acc_bit_width=16,
weight_bit_width=4,
weight_quant=CommonIntAccumulatorAwareZeroCenterWeightQuant),
"quant_resnet18_w4a4_a2q_plus_16b":
"quant_resnet18_w4a4_a2q_plus_15b":
partial(
quant_resnet18,
act_bit_width=4,
acc_bit_width=16,
acc_bit_width=15,
weight_bit_width=4,
weight_quant=CommonIntAccumulatorAwareZeroCenterWeightQuant),
"quant_resnet18_w4a4_a2q_plus_14b":
partial(
quant_resnet18,
act_bit_width=4,
acc_bit_width=14,
weight_bit_width=4,
weight_quant=CommonIntAccumulatorAwareZeroCenterWeightQuant),
"quant_resnet18_w4a4_a2q_plus_13b":
partial(
quant_resnet18,
act_bit_width=4,
acc_bit_width=13,
weight_bit_width=4,
weight_quant=CommonIntAccumulatorAwareZeroCenterWeightQuant),
"quant_resnet18_w4a4_a2q_plus_12b":
@@ -74,16 +114,52 @@

root_url = 'https://github.com/Xilinx/brevitas/releases/download/'
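
The model_impl registry above binds each name to a functools.partial constructor, so model construction is a dict lookup plus a call. A short usage sketch (not from the commit):

```python
# Each registry value takes no arguments; the bit widths and weight quantizer
# were already bound via functools.partial.
name = "quant_resnet18_w4a4_a2q_16b"
model = model_impl[name]()  # same as quant_resnet18(act_bit_width=4, acc_bit_width=16, ...)
```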

model_url = {"float_resnet18": f"{root_url}/a2q/resnet18-e9872c01.pth"}
model_url = {
"float_resnet18":
f"{root_url}/ep_init/float_resnet18-1d98d23a.pth",
"quant_resnet18_w4a4_a2q_12b":
f"{root_url}/ep_init/quant_resnet18_w4a4_a2q_12b-8a440436.pth",
"quant_resnet18_w4a4_a2q_13b":
f"{root_url}/ep_init/quant_resnet18_w4a4_a2q_13b-8c31a2b1.pth",
"quant_resnet18_w4a4_a2q_14b":
f"{root_url}/ep_init/quant_resnet18_w4a4_a2q_14b-267f237b.pth",
"quant_resnet18_w4a4_a2q_15b":
f"{root_url}/ep_init/quant_resnet18_w4a4_a2q_15b-0d5bf266.pth",
"quant_resnet18_w4a4_a2q_16b":
f"{root_url}/ep_init/quant_resnet18_w4a4_a2q_16b-d0af41f1.pth",
"quant_resnet18_w4a4_a2q_plus_12b":
f"{root_url}/ep_init/quant_resnet18_w4a4_a2q_plus_12b-d69f003b.pth",
"quant_resnet18_w4a4_a2q_plus_13b":
f"{root_url}/ep_init/quant_resnet18_w4a4_a2q_plus_13b-332aaf81.pth",
"quant_resnet18_w4a4_a2q_plus_14b":
f"{root_url}/ep_init/quant_resnet18_w4a4_a2q_plus_14b-5a2d11aa.pth",
"quant_resnet18_w4a4_a2q_plus_15b":
f"{root_url}/ep_init/quant_resnet18_w4a4_a2q_plus_15b-3c89551a.pth",
"quant_resnet18_w4a4_a2q_plus_16b":
f"{root_url}/ep_init/quant_resnet18_w4a4_a2q_plus_16b-19973380.pth"}


def get_model_by_name(
model_name: str,
pretrained: bool = False,
init_from_float_checkpoint: bool = False) -> nn.Module:

def get_model_by_name(model_name: str, pretrained: bool) -> nn.Module:
assert model_name in model_impl, f"Error: {model_name} not implemented."
model: Module = model_impl[model_name]()
if pretrained:
checkpoint = model_url['float_resnet18']

if init_from_float_checkpoint:
checkpoint = model_url["float_resnet18"]
state_dict = hub.load_state_dict_from_url(checkpoint, progress=True, map_location='cpu')
model.load_state_dict(state_dict, strict=True)

elif pretrained:
checkpoint = model_url[model_name]
state_dict = hub.load_state_dict_from_url(checkpoint, progress=True, map_location='cpu')
if model_name.startswith("quant"):
# the bias-corrected checkpoints include bias tensors that a freshly-constructed quant model lacks; run a dummy bias-correction pass so the model's keys match the checkpoint before loading with strict=True
_prepare_bias_corrected_quant_model(model)
model.load_state_dict(state_dict, strict=True)

return model
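
A usage sketch of the three initialization modes exposed above (not from the commit):

```python
# Random init, warm start from the float checkpoint, or fully pre-trained
# (the quantized checkpoints already include bias correction). The training
# script sets brevitas config.IGNORE_MISSING_KEYS = True before warm starts.
model = get_model_by_name("quant_resnet18_w4a4_a2q_16b")
model = get_model_by_name("quant_resnet18_w4a4_a2q_16b", init_from_float_checkpoint=True)
model = get_model_by_name("quant_resnet18_w4a4_a2q_16b", pretrained=True)
```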


@@ -163,6 +239,30 @@ def get_cifar10_dataloaders(
return trainloader, testloader


def apply_bias_correction(calib_loader, model: nn.Module):
model.eval()
dtype = next(model.parameters()).dtype
device = next(model.parameters()).device
with torch.no_grad():
with bias_correction_mode(model):
for (images, _) in tqdm(calib_loader):
images = images.to(device)
images = images.to(dtype)
model(images)


def _prepare_bias_corrected_quant_model(model: nn.Module):
model.eval()
dtype = next(model.parameters()).dtype
device = next(model.parameters()).device
images = torch.randn(10, 3, 32, 32)
images = images.to(device)
images = images.to(dtype)
with torch.no_grad():
with bias_correction_mode(model):
model(images)
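
A note on the helper above (my reading of the comment in get_model_by_name, stated as an assumption): running one dummy batch inside bias_correction_mode registers bias tensors on layers that were built without them, so a freshly constructed quant model ends up with the same state_dict keys as the bias-corrected checkpoints before the strict load. A quick way to see the effect:

```python
# Sketch (assumption): compare state_dict keys before/after the dummy pass.
model = model_impl["quant_resnet18_w4a4_a2q_16b"]()
before = set(model.state_dict().keys())
_prepare_bias_corrected_quant_model(model)
after = set(model.state_dict().keys())
print(after - before)  # expected: newly registered '*.bias' entries (assumption)
```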


def train_for_epoch(trainloader, model, criterion, optimizer, reg_weight: float = 1e-3):
model.train()
model = model.to(device)
