diff --git a/src/brevitas_examples/super_resolution/README.md b/src/brevitas_examples/super_resolution/README.md
index 2e762393e..9e0068c27 100644
--- a/src/brevitas_examples/super_resolution/README.md
+++ b/src/brevitas_examples/super_resolution/README.md
@@ -1,22 +1,29 @@
 # Integer-Quantized Super Resolution Experiments with Brevitas

-This directory contains training scripts to demonstrate how to train integer-quantized super resolution models using [Brevitas](https://github.com/Xilinx/brevitas).
+This directory contains scripts demonstrating how to train integer-quantized super resolution models using [Brevitas](https://github.com/Xilinx/brevitas).
 Code is also provided to demonstrate accumulator-aware quantization (A2Q) as proposed in our ICCV 2023 paper "[A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance](https://arxiv.org/abs/2308.13504)".

 ## Experiments

 All models are trained on the BSD300 dataset to upsample images by 2x.
-Target images are center cropped to 512x512.
+Target images are cropped to 512x512.
+During training, random cropping is applied along with random vertical and horizontal flips.
+During inference, center cropping is applied.
 Inputs are then downscaled by 2x and then used to train the model directly in the RGB space.
 Note that this is a difference from many academic works that train only on the Y-channel in YCbCr format.

 | Model Name | Upscale Factor | Weight quantization | Activation quantization | Peak Signal-to-Noise Ratio |
 |-----------------------------|----------------|---------------------|-------------------------|----------------------------|
-| [float_espcn_x2](https://github.com/Xilinx/brevitas/releases/download/super_res_r0/float_espcn_x2-2f3821e3.pth) | x2 | float32 | float32 | 30.37 |
-| [quant_espcn_x2_w8a8_base](https://github.com/Xilinx/brevitas/releases/download/super_res_r0/quant_espcn_x2_w8a8_base-7d54e29c.pth) | x2 | int8 | (u)int8 | 30.16 |
-| [quant_espcn_x2_w8a8_a2q_32b](https://github.com/Xilinx/brevitas/releases/download/super_res_r0/quant_espcn_x2_w8a8_a2q_32b-0b1f361d.pth) | x2 | int8 | (u)int8 | 30.80 |
-| [quant_espcn_x2_w8a8_a2q_16b](https://github.com/Xilinx/brevitas/releases/download/super_res_r0/quant_espcn_x2_w8a8_a2q_16b-3c4acd35.pth) | x2 | int8 | (u)int8 | 29.38 |
 | bicubic_interp | x2 | N/A | N/A | 28.71 |
+| [float_espcn_x2]() | x2 | float32 | float32 | 31.03 |
+||
+| [quant_espcn_x2_w8a8_base]() | x2 | int8 | (u)int8 | 30.96 |
+| [quant_espcn_x2_w8a8_a2q_32b]() | x2 | int8 | (u)int8 | 30.79 |
+| [quant_espcn_x2_w8a8_a2q_16b]() | x2 | int8 | (u)int8 | 30.56 |
+||
+| [quant_espcn_x2_w4a4_base]() | x2 | int4 | (u)int4 | 30.30 |
+| [quant_espcn_x2_w4a4_a2q_32b]() | x2 | int4 | (u)int4 | 30.27 |
+| [quant_espcn_x2_w4a4_a2q_13b]() | x2 | int4 | (u)int4 | 30.24 |

 ## Train
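For orientation, the model names in the table above map to configurations registered in `models/__init__.py` (updated further below in this patch). A minimal, illustrative loading sketch only, assuming `brevitas` and the `brevitas_examples` package are installed:

# Illustrative sketch (not part of the patch): build one of the named configurations
# from the table above. pretrained=True additionally requires a released checkpoint
# URL to be registered in model_url.
from brevitas_examples.super_resolution.models import get_model_by_name

model = get_model_by_name('quant_espcn_x2_w4a4_a2q_13b', pretrained=False)
model.eval()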
diff --git a/src/brevitas_examples/super_resolution/eval_model.py b/src/brevitas_examples/super_resolution/eval_model.py
index a567469d1..60883dea6 100644
--- a/src/brevitas_examples/super_resolution/eval_model.py
+++ b/src/brevitas_examples/super_resolution/eval_model.py
@@ -31,13 +31,22 @@
 parser = argparse.ArgumentParser(description='PyTorch BSD300 Validation')
 parser.add_argument('--data_root', help='Path to folder containing BSD300 val folder')
-parser.add_argument('--model_path', default=None, help='Path to PyTorch checkpoint')
+parser.add_argument('--model_path', default=None, help='Path to PyTorch checkpoint. Default = None')
 parser.add_argument(
-    '--save_path', type=str, default='outputs/', help='Save path for exported model')
+    '--save_path',
+    type=str,
+    default='outputs/',
+    help='Save path for exported model. Default = outputs/')
 parser.add_argument(
-    '--model', type=str, default='quant_espcn_x2_w8a8_base', help='Name of the model configuration')
-parser.add_argument('--workers', type=int, default=0, help='Number of data loading workers')
-parser.add_argument('--batch_size', type=int, default=16, help='Minibatch size')
+    '--model',
+    type=str,
+    default='quant_espcn_x2_w8a8_base',
+    help='Name of the model configuration. Default = quant_espcn_x2_w8a8_base')
+parser.add_argument(
+    '--workers', type=int, default=0, help='Number of data loading workers. Default = 0')
+parser.add_argument('--batch_size', type=int, default=16, help='Minibatch size. Default = 16')
+parser.add_argument(
+    '--crop_size', type=int, default=512, help='The size to crop the image. Default = 512')
 parser.add_argument('--use_pretrained', action='store_true', default=False)
 parser.add_argument('--eval_acc_bw', action='store_true', default=False)
 parser.add_argument('--save_model_io', action='store_true', default=False)
@@ -60,6 +69,7 @@ def main():
         num_workers=args.workers,
         batch_size=args.batch_size,
         upscale_factor=model.upscale_factor,
+        crop_size=args.crop_size,
         download=True)

     test_psnr = evaluate_avg_psnr(testloader, model)
diff --git a/src/brevitas_examples/super_resolution/models/__init__.py b/src/brevitas_examples/super_resolution/models/__init__.py
index fa5bff7bd..872624a85 100644
--- a/src/brevitas_examples/super_resolution/models/__init__.py
+++ b/src/brevitas_examples/super_resolution/models/__init__.py
@@ -5,6 +5,7 @@
 from typing import Union

 from torch import hub
+import torch.nn as nn

 from .espcn import *

@@ -26,7 +27,23 @@
         upscale_factor=2,
         weight_bit_width=8,
         act_bit_width=8,
-        acc_bit_width=16)}
+        acc_bit_width=16),
+    'quant_espcn_x2_w4a4_base':
+        partial(quant_espcn_base, upscale_factor=2, weight_bit_width=4, act_bit_width=4),
+    'quant_espcn_x2_w4a4_a2q_32b':
+        partial(
+            quant_espcn_a2q,
+            upscale_factor=2,
+            weight_bit_width=4,
+            act_bit_width=4,
+            acc_bit_width=32),
+    'quant_espcn_x2_w4a4_a2q_13b':
+        partial(
+            quant_espcn_a2q,
+            upscale_factor=2,
+            weight_bit_width=4,
+            act_bit_width=4,
+            acc_bit_width=13)}

 root_url = 'https://github.com/Xilinx/brevitas/releases/download/super_res-r0'

@@ -40,8 +57,10 @@
 def get_model_by_name(name: str, pretrained: bool = False) -> Union[FloatESPCN, QuantESPCN]:
     if name not in model_impl.keys():
         raise NotImplementedError(f"{name} does not exist.")
-    model = model_impl[name]()
+    model: nn.Module = model_impl[name]()
     if pretrained:
+        if name not in model_url:
+            raise NotImplementedError(f"Error: {name} does not have a pre-trained checkpoint.")
         checkpoint = model_url[name]
         state_dict = hub.load_state_dict_from_url(checkpoint, progress=True, map_location='cpu')
         model.load_state_dict(state_dict, strict=True)
diff --git a/src/brevitas_examples/super_resolution/models/common.py b/src/brevitas_examples/super_resolution/models/common.py
index 9d03aba51..d3022d089 100644
--- a/src/brevitas_examples/super_resolution/models/common.py
+++ b/src/brevitas_examples/super_resolution/models/common.py
@@ -25,8 +25,8 @@ class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat):


 class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant):
-    pre_scaling_min_val = 1e-8
-    scaling_min_val = 1e-8
+    pre_scaling_min_val = 1e-10
+    scaling_min_val = 1e-10


 class CommonIntActQuant(Int8ActPerTensorFloat):
diff --git a/src/brevitas_examples/super_resolution/models/espcn.py b/src/brevitas_examples/super_resolution/models/espcn.py
index 3af2846f8..f3123fcb9 100644
--- a/src/brevitas_examples/super_resolution/models/espcn.py
+++ b/src/brevitas_examples/super_resolution/models/espcn.py
@@ -12,7 +12,6 @@
 from .common import CommonIntWeightPerChannelQuant
 from .common import CommonUintActQuant
 from .common import ConstUint8ActQuant
-from .common import QuantNearestNeighborConvolution

 __all__ = [
     "float_espcn", "quant_espcn", "quant_espcn_a2q", "quant_espcn_base", "FloatESPCN", "QuantESPCN"]
@@ -29,9 +28,7 @@ def weight_init(layer):


 class FloatESPCN(nn.Module):
-    """Floating-point version of FINN-Friendly Quantized Efficient Sub-Pixel Convolution
-    Network (ESPCN) as used in Colbert et al. (2023) - `Quantized Neural Networks for
-    Low-Precision Accumulation with Guaranteed Overflow Avoidance`."""
+    """Floating-point version of Efficient Sub-Pixel Convolution Network (ESPCN)"""

     def __init__(self, upscale_factor: int = 3, num_channels: int = 3):
         super(FloatESPCN, self).__init__()
@@ -48,17 +45,14 @@ def __init__(self, upscale_factor: int = 3, num_channels: int = 3):
             in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
         self.conv3 = nn.Conv2d(
             in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, bias=True)
-        self.conv4 = nn.Sequential()
-        self.conv4.add_module("interp", nn.UpsamplingNearest2d(scale_factor=upscale_factor))
-        self.conv4.add_module(
-            "conv",
-            nn.Conv2d(
-                in_channels=32,
-                out_channels=num_channels,
-                kernel_size=3,
-                stride=1,
-                padding=1,
-                bias=True))
+        self.conv4 = nn.Conv2d(
+            in_channels=32,
+            out_channels=num_channels * pow(upscale_factor, 2),
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias=False)
+        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

         self.bn1 = nn.BatchNorm2d(64)
         self.bn2 = nn.BatchNorm2d(64)
@@ -75,15 +69,13 @@ def forward(self, inp: Tensor):
         x = self.relu(self.bn1(self.conv1(x)))
         x = self.relu(self.bn2(self.conv2(x)))
         x = self.relu(self.bn3(self.conv3(x)))
-        x = self.conv4(x)
-        x = self.out(x)
+        x = self.pixel_shuffle(self.conv4(x))
+        x = self.out(x)  # To mirror quant version
         return x


 class QuantESPCN(FloatESPCN):
-    """FINN-Friendly Quantized Efficient Sub-Pixel Convolution Network (ESPCN) as
-    used in Colbert et al. (2023) - `Quantized Neural Networks for Low-Precision
-    Accumulation with Guaranteed Overflow Avoidance`."""
+    """FINN-Friendly Quantized Efficient Sub-Pixel Convolution Network (ESPCN)"""

     def __init__(
             self,
@@ -130,27 +122,28 @@ def __init__(
             kernel_size=3,
             stride=1,
             padding=1,
+            bias=True,
             input_bit_width=act_bit_width,
             input_quant=CommonUintActQuant,
             weight_bit_width=weight_bit_width,
             weight_accumulator_bit_width=acc_bit_width,
             weight_quant=weight_quant)
-        # Quantizing the weights and input activations to 8-bit integers
-        # and not applying accumulator constraint to the final convolution
-        # layer (i.e., accumulator_bit_width=32).
-        self.conv4 = QuantNearestNeighborConvolution(
+        # We quantize the weights and input activations of the final layer
+        # to 8-bit integers. We do not apply the accumulator constraint to
+        # the final convolution layer. FINN does not currently support
+        # per-channel quantization or biases for sub-pixel convolution layers.
+        self.conv4 = qnn.QuantConv2d(
             in_channels=32,
-            out_channels=num_channels,
+            out_channels=num_channels * pow(upscale_factor, 2),
             kernel_size=3,
             stride=1,
             padding=1,
-            upscale_factor=upscale_factor,
-            bias=True,
-            signed_act=False,
-            act_bit_width=IO_DATA_BIT_WIDTH,
-            acc_bit_width=IO_ACC_BIT_WIDTH,
-            weight_quant=weight_quant,
-            weight_bit_width=IO_DATA_BIT_WIDTH)
+            bias=False,
+            input_bit_width=act_bit_width,
+            input_quant=CommonUintActQuant,
+            weight_bit_width=IO_DATA_BIT_WIDTH,
+            weight_quant=CommonIntWeightPerChannelQuant,
+            weight_scaling_per_output_channel=False)

         self.bn1 = nn.BatchNorm2d(64)
         self.bn2 = nn.BatchNorm2d(64)
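The updated ESPCN above replaces the nearest-neighbor upsampling head with a sub-pixel convolution: conv4 produces num_channels * upscale_factor**2 channels and nn.PixelShuffle rearranges them into the upscaled image. A short, self-contained sketch of that channel arithmetic with assumed sizes (illustrative only, not part of the patch):

# Illustrative sketch: nn.PixelShuffle(r) rearranges (N, C*r*r, H, W) into (N, C, H*r, W*r),
# which is why conv4 above produces num_channels * pow(upscale_factor, 2) output channels.
import torch
import torch.nn as nn

r, num_channels = 2, 3  # assumed upscale factor and RGB channels
conv4 = nn.Conv2d(32, num_channels * r ** 2, kernel_size=3, stride=1, padding=1, bias=False)
pixel_shuffle = nn.PixelShuffle(r)

x = torch.randn(1, 32, 64, 64)
y = pixel_shuffle(conv4(x))
print(y.shape)  # torch.Size([1, 3, 128, 128])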
diff --git a/src/brevitas_examples/super_resolution/train_model.py b/src/brevitas_examples/super_resolution/train_model.py
index 52dfe8c72..f446355cb 100644
--- a/src/brevitas_examples/super_resolution/train_model.py
+++ b/src/brevitas_examples/super_resolution/train_model.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: BSD-3-Clause

 import argparse
+import copy
 from hashlib import sha256
 import json
 import os
@@ -46,10 +47,10 @@
 parser.add_argument('--workers', type=int, default=0, help='Number of data loading workers')
 parser.add_argument('--batch_size', type=int, default=8, help='Minibatch size')
 parser.add_argument('--learning_rate', type=float, default=1e-3, help='Learning rate')
-parser.add_argument('--total_epochs', type=int, default=100, help='Total number of training epochs')
-parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay')
+parser.add_argument('--total_epochs', type=int, default=500, help='Total number of training epochs')
+parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight decay')
 parser.add_argument('--step_size', type=int, default=1)
-parser.add_argument('--gamma', type=float, default=0.98)
+parser.add_argument('--gamma', type=float, default=0.999)
 parser.add_argument('--eval_acc_bw', action='store_true', default=False)
 parser.add_argument('--save_pth_ckpt', action='store_true', default=False)
 parser.add_argument('--save_model_io', action='store_true', default=False)
@@ -81,7 +82,6 @@ def main():
         args.data_root,
         num_workers=args.workers,
         batch_size=args.batch_size,
-        batch_size_test=1,
         upscale_factor=model.upscale_factor,
         download=True)
     criterion = nn.MSELoss()
@@ -92,17 +92,25 @@ def main():
     scheduler = lrs.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

     # train model
+    best_psnr, best_weights = 0., copy.deepcopy(model.state_dict())
     for ep in range(args.total_epochs):
         train_loss = train_for_epoch(trainloader, model, criterion, optimizer)
         test_psnr = evaluate_avg_psnr(testloader, model)
         scheduler.step()
         print(f"[Epoch {ep:03d}] train_loss={train_loss:.4f}, test_psnr={test_psnr:.2f}")
+        if test_psnr >= best_psnr:
+            best_weights = copy.deepcopy(model.state_dict())
+            best_psnr = test_psnr
+    model.load_state_dict(best_weights)
+    model = model.to(device)
+    test_psnr = evaluate_avg_psnr(testloader, model)
+    print(f"Final test_psnr={test_psnr:.2f}")

     # save checkpoint
     os.makedirs(args.save_path, exist_ok=True)
     if args.save_pth_ckpt:
         ckpt_path = f"{args.save_path}/{args.model}.pth"
-        torch.save(model.state_dict(), ckpt_path)
+        torch.save(best_weights, ckpt_path)
         with open(ckpt_path, "rb") as _file:
             bytes = _file.read()
         model_tag = sha256(bytes).hexdigest()[:8]
diff --git a/src/brevitas_examples/super_resolution/utils/dataset.py b/src/brevitas_examples/super_resolution/utils/dataset.py
index edadbb95a..cbf1004be 100644
--- a/src/brevitas_examples/super_resolution/utils/dataset.py
+++ b/src/brevitas_examples/super_resolution/utils/dataset.py
@@ -50,6 +50,9 @@
 import torch.utils.data as data
 from torchvision.transforms import CenterCrop
 from torchvision.transforms import Compose
+from torchvision.transforms import RandomCrop
+from torchvision.transforms import RandomHorizontalFlip
+from torchvision.transforms import RandomVerticalFlip
 from torchvision.transforms import Resize
 from torchvision.transforms import ToTensor

@@ -79,21 +82,21 @@ def load_img_rbg(filepath):


 class DatasetFromFolder(data.Dataset):

-    def __init__(self, image_dir, input_transform=None, target_transform=None):
+    def __init__(self, image_dir, shared_transform, input_transform, target_transform):
         super(DatasetFromFolder, self).__init__()
         self.image_filenames = [
             os.path.join(image_dir, x) for x in os.listdir(image_dir) if is_valid_image_file(x)]
+        self.shared_transform = shared_transform
         self.input_transform = input_transform
         self.target_transform = target_transform

     def __getitem__(self, index):
         input = load_img_rbg(self.image_filenames[index])
+        input = self.shared_transform(input)
         target = input.copy()
-        if self.input_transform:
-            input = self.input_transform(input)
-        if self.target_transform:
-            target = self.target_transform(target)
+        input = self.input_transform(input)
+        target = self.target_transform(target)
         return input, target

     def __len__(self):
@@ -122,42 +125,50 @@ def calculate_valid_crop_size(crop_size, upscale_factor):
     return crop_size - (crop_size % upscale_factor)


+def train_transforms(crop_size):
+    return Compose([
+        RandomCrop(crop_size, pad_if_needed=True), RandomHorizontalFlip(), RandomVerticalFlip()])
+
+
+def test_transforms(crop_size):
+    return Compose([CenterCrop(crop_size)])
+
+
 def input_transform(crop_size, upscale_factor):
     return Compose([
-        CenterCrop(crop_size),
         Resize(crop_size // upscale_factor),
         ToTensor(),])


-def target_transform(crop_size):
-    return Compose([
-        CenterCrop(crop_size),
-        ToTensor(),])
+def target_transform():
+    return Compose([ToTensor()])


-def get_training_set(upscale_factor: int, root_dir: str, crop_size: int = 256):
+def get_training_set(upscale_factor: int, root_dir: str, crop_size: int):
     train_dir = os.path.join(root_dir, "train")
     crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
     return DatasetFromFolder(
         train_dir,
+        shared_transform=train_transforms(crop_size),
         input_transform=input_transform(crop_size, upscale_factor),
-        target_transform=target_transform(crop_size))
+        target_transform=target_transform())


-def get_test_set(upscale_factor: int, root_dir: str, crop_size: int = 256):
+def get_test_set(upscale_factor: int, root_dir: str, crop_size: int):
     test_dir = os.path.join(root_dir, "test")
     crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
     return DatasetFromFolder(
         test_dir,
+        shared_transform=test_transforms(crop_size),
         input_transform=input_transform(crop_size, upscale_factor),
-        target_transform=target_transform(crop_size))
+        target_transform=target_transform())


 def get_bsd300_dataloaders(
         data_root: str,
         num_workers: int = 0,
         batch_size: int = 32,
-        batch_size_test: int = 32,
+        batch_size_test: int = 100,
         pin_memory: bool = True,
         upscale_factor: int = 3,
         crop_size: int = 512,
@@ -171,10 +182,10 @@
         num_workers (int): Number of workers to use for both dataloaders. Default: 0
         batch_size (int): Size of batches to use for the training dataloader. Default: 32
         batch_size_test (int): Size of batches to use for the testing dataloader. When
-            None, then batch_size_test = batch_size. Default: 32
+            None, then batch_size_test = batch_size. Default: 100
         pin_memory (bool): Whether or not to pin the memory for both dataloaders. Default: True
         upscale_factor (int): The upscale factor for the super resolution task. Default: 3
-        crop_size (int): The size to crop images for upscaling. Default 512
+        crop_size (int): The size to crop images for upscaling. Default: 512
         download (bool): Whether or not to download the dataset. Default: False
     """
     data_root = download_bsd300(data_root, download)
diff --git a/src/brevitas_examples/super_resolution/utils/evaluate.py b/src/brevitas_examples/super_resolution/utils/evaluate.py
index 0319700e6..39c05a434 100644
--- a/src/brevitas_examples/super_resolution/utils/evaluate.py
+++ b/src/brevitas_examples/super_resolution/utils/evaluate.py
@@ -29,10 +29,6 @@ def _calc_min_acc_bit_width(module: QuantWBIOL) -> Tensor:


 def evaluate_accumulator_bit_widths(model: nn.Module, inp: Tensor):
-    if isinstance(model, QuantESPCN):
-        # Need to cache the quantized input to the final convolution to be able to evaluate the
-        # accumulator bounds since we need the input bit-width, which is specified at runtime.
-        model.conv4.conv.cache_inference_quant_inp = True
     model(inp)  # collect quant inputs now that caching is enabled
     stats = dict()
     for name, module in model.named_modules():
diff --git a/src/brevitas_examples/super_resolution/utils/train.py b/src/brevitas_examples/super_resolution/utils/train.py
index 5d12e8b19..94f9c36c3 100644
--- a/src/brevitas_examples/super_resolution/utils/train.py
+++ b/src/brevitas_examples/super_resolution/utils/train.py
@@ -4,6 +4,9 @@
 import torch
 from torch import Tensor

+from brevitas.core.scaling.pre_scaling import AccumulatorAwareParameterPreScaling
+from brevitas.function import abs_binary_sign_grad
+
 device = 'cuda' if torch.cuda.is_available() else 'cpu'


@@ -14,17 +17,45 @@ def calc_average_psnr(ref_images: Tensor, gen_images: Tensor, eps: float = 1e-10
     return psnr.mean()


-def train_for_epoch(trainloader, model, criterion, optimizer):
-    tot_loss = 0.
-    for i, (images, targets) in enumerate(trainloader):
+def train_for_epoch(trainloader, model, criterion, optimizer, reg_weight: float = 1e-3):
+    model.train()
+
+    tot_loss, reg_penalty = 0., 0.
+
+    def acc_reg_penalty(module: AccumulatorAwareParameterPreScaling, inp, output):
+        """Accumulate the regularization penalty across constrained layers"""
+        nonlocal reg_penalty
+        (weights, input_bit_width, input_is_signed) = inp
+        s = module.scaling_impl(weights)  # s
+        g = abs_binary_sign_grad(module.restrict_clamp_scaling(module.value))  # g
+        T = module.get_upper_bound_on_l1_norm(input_bit_width, input_is_signed)  # T / s
+        cur_penalty = torch.relu(g - (T * s)).sum()
+        reg_penalty += cur_penalty
+        return output
+
+    # Register a forward hook to accumulate the regularization penalty
+    hook_fns = list()
+    for mod in model.modules():
+        if isinstance(mod, AccumulatorAwareParameterPreScaling):
+            hook = mod.register_forward_hook(acc_reg_penalty)
+            hook_fns.append(hook)
+
+    for _, (images, targets) in enumerate(trainloader):
         optimizer.zero_grad()
         images = images.to(device)
         targets = targets.to(device)
         outputs = model(images)
-        loss: Tensor = criterion(outputs, targets)
+        task_loss: Tensor = criterion(outputs, targets)
+        loss = task_loss + (reg_weight * reg_penalty)
         loss.backward()
         optimizer.step()
-        tot_loss += loss.item() * images.size(0)
+        reg_penalty = 0.  # reset the accumulated regularization penalty
+        tot_loss += task_loss.item() * images.size(0)
+
+    # Remove the registered forward hooks before exiting
+    for hook in hook_fns:
+        hook.remove()
+
     avg_loss = tot_loss / len(trainloader.dataset)
     return avg_loss
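The acc_reg_penalty hook above penalizes, per constrained layer, only the amount by which the learned l1-norm budget g exceeds the accumulator-implied bound T mapped back into the weight domain (T * s). A small, self-contained sketch of that quantity with made-up values (illustrative only, not part of the patch):

# Illustrative sketch of the per-layer quantity summed by acc_reg_penalty above:
# g is the learned l1-norm budget, s the weight scale, and T the accumulator-implied
# bound on the l1 norm of the integer weights; only the excess over T * s is penalized.
import torch

def a2q_penalty(g: torch.Tensor, s: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
    return torch.relu(g - T * s).sum()

g = torch.tensor([0.9, 1.5])    # assumed per-channel budgets
s = torch.tensor([0.1, 0.1])    # assumed weight scales
T = torch.tensor([12.0, 12.0])  # assumed accumulator bounds
print(a2q_penalty(g, s, T))     # tensor(0.3000): only the second channel exceeds its budget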
diff --git a/tests/brevitas_examples/test_examples_import.py b/tests/brevitas_examples/test_examples_import.py
index 3f728428c..c80dd3161 100644
--- a/tests/brevitas_examples/test_examples_import.py
+++ b/tests/brevitas_examples/test_examples_import.py
@@ -1,6 +1,11 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause

+import pytest
+
+from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant
+from brevitas.quant.scaled_int import Int8WeightPerChannelFloat
+

 def test_import_bnn_pynq():
     from brevitas_examples.bnn_pynq import cnv_1w1a
@@ -31,3 +36,17 @@ def test_import_stt():
     from brevitas_examples.speech_to_text import quant_quartznet_perchannelscaling_4b
     from brevitas_examples.speech_to_text import quant_quartznet_perchannelscaling_8b
     from brevitas_examples.speech_to_text import quant_quartznet_pertensorscaling_8b
+
+
+@pytest.mark.parametrize("upscale_factor", [2, 3, 4])
+@pytest.mark.parametrize("num_channels", [1, 3])
+@pytest.mark.parametrize(
+    "weight_quant", [Int8WeightPerChannelFloat, Int8AccumulatorAwareWeightQuant])
+def test_super_resolution_float_and_quant_models_match(upscale_factor, num_channels, weight_quant):
+    import brevitas.config as config
+    from brevitas_examples.super_resolution.models import float_espcn
+    from brevitas_examples.super_resolution.models import quant_espcn
+    config.IGNORE_MISSING_KEYS = True
+    float_model = float_espcn(upscale_factor, num_channels)
+    quant_model = quant_espcn(upscale_factor, num_channels, weight_quant=weight_quant)
+    quant_model.load_state_dict(float_model.state_dict())