From e9de10f49e39f7131911c7a6d590716ed1522ca3 Mon Sep 17 00:00:00 2001
From: icolbert
Date: Wed, 13 Sep 2023 16:25:28 -0700
Subject: [PATCH] Pre-commit fixes

---
 .../super_resolution/models/__init__.py |  1 +
 .../super_resolution/models/common.py   |  2 +-
 .../super_resolution/models/espcn.py    | 25 ++++---------------
 .../super_resolution/utils/dataset.py   | 10 +++-----
 .../super_resolution/utils/train.py     |  3 ++-
 5 files changed, 13 insertions(+), 28 deletions(-)

diff --git a/src/brevitas_examples/super_resolution/models/__init__.py b/src/brevitas_examples/super_resolution/models/__init__.py
index 3ea12fea0..f3c9a0ad5 100644
--- a/src/brevitas_examples/super_resolution/models/__init__.py
+++ b/src/brevitas_examples/super_resolution/models/__init__.py
@@ -3,6 +3,7 @@
 from functools import partial
 from typing import Union
+
 from torch import hub
 import torch.nn as nn
diff --git a/src/brevitas_examples/super_resolution/models/common.py b/src/brevitas_examples/super_resolution/models/common.py
index f781f1487..d3022d089 100644
--- a/src/brevitas_examples/super_resolution/models/common.py
+++ b/src/brevitas_examples/super_resolution/models/common.py
@@ -24,7 +24,7 @@ class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat):
     scaling_per_output_channel = True
 
 
-class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant): 
+class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant):
     pre_scaling_min_val = 1e-10
     scaling_min_val = 1e-10
diff --git a/src/brevitas_examples/super_resolution/models/espcn.py b/src/brevitas_examples/super_resolution/models/espcn.py
index 3d10e6581..f3123fcb9 100644
--- a/src/brevitas_examples/super_resolution/models/espcn.py
+++ b/src/brevitas_examples/super_resolution/models/espcn.py
@@ -14,12 +14,7 @@
 from .common import ConstUint8ActQuant
 
 __all__ = [
-    "float_espcn",
-    "quant_espcn",
-    "quant_espcn_a2q",
-    "quant_espcn_base",
-    "FloatESPCN",
-    "QuantESPCN"]
+    "float_espcn", "quant_espcn", "quant_espcn_a2q", "quant_espcn_base", "FloatESPCN", "QuantESPCN"]
 
 IO_DATA_BIT_WIDTH = 8
 IO_ACC_BIT_WIDTH = 32
@@ -47,19 +42,9 @@ def __init__(self, upscale_factor: int = 3, num_channels: int = 3):
             padding=2,
             bias=True)
         self.conv2 = nn.Conv2d(
-            in_channels=64,
-            out_channels=64,
-            kernel_size=3,
-            stride=1,
-            padding=1,
-            bias=True)
+            in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
         self.conv3 = nn.Conv2d(
-            in_channels=64,
-            out_channels=32,
-            kernel_size=3,
-            stride=1,
-            padding=1,
-            bias=True)
+            in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, bias=True)
         self.conv4 = nn.Conv2d(
             in_channels=32,
             out_channels=num_channels * pow(upscale_factor, 2),
@@ -85,7 +70,7 @@ def forward(self, inp: Tensor):
         x = self.relu(self.bn2(self.conv2(x)))
         x = self.relu(self.bn3(self.conv3(x)))
         x = self.pixel_shuffle(self.conv4(x))
-        x = self.out(x) # To mirror quant version
+        x = self.out(x)  # To mirror quant version
         return x
@@ -145,7 +130,7 @@ def __init__(
             weight_quant=weight_quant)
         # We quantize the weights and input activations of the final layer
         # to 8-bit integers. We do not apply the accumulator constraint to
-        # the final convolution layer. FINN does not currently support 
+        # the final convolution layer. FINN does not currently support
         # per-tensor quantization or biases for sub-pixel convolution layers.
         self.conv4 = qnn.QuantConv2d(
             in_channels=32,
diff --git a/src/brevitas_examples/super_resolution/utils/dataset.py b/src/brevitas_examples/super_resolution/utils/dataset.py
index a3ca12493..cbf1004be 100644
--- a/src/brevitas_examples/super_resolution/utils/dataset.py
+++ b/src/brevitas_examples/super_resolution/utils/dataset.py
@@ -50,11 +50,11 @@
 import torch.utils.data as data
 from torchvision.transforms import CenterCrop
 from torchvision.transforms import Compose
-from torchvision.transforms import Resize
-from torchvision.transforms import ToTensor
 from torchvision.transforms import RandomCrop
-from torchvision.transforms import RandomVerticalFlip
 from torchvision.transforms import RandomHorizontalFlip
+from torchvision.transforms import RandomVerticalFlip
+from torchvision.transforms import Resize
+from torchvision.transforms import ToTensor
 
 __all__ = ["get_bsd300_dataloaders"]
@@ -127,9 +127,7 @@ def calculate_valid_crop_size(crop_size, upscale_factor):
 
 def train_transforms(crop_size):
     return Compose([
-        RandomCrop(crop_size, pad_if_needed=True),
-        RandomHorizontalFlip(),
-        RandomVerticalFlip()])
+        RandomCrop(crop_size, pad_if_needed=True), RandomHorizontalFlip(), RandomVerticalFlip()])
 
 
 def test_transforms(crop_size):
diff --git a/src/brevitas_examples/super_resolution/utils/train.py b/src/brevitas_examples/super_resolution/utils/train.py
index 24e859ac0..94f9c36c3 100644
--- a/src/brevitas_examples/super_resolution/utils/train.py
+++ b/src/brevitas_examples/super_resolution/utils/train.py
@@ -3,8 +3,9 @@
 import torch
 from torch import Tensor
 
-from brevitas.function import abs_binary_sign_grad
+
 from brevitas.core.scaling.pre_scaling import AccumulatorAwareParameterPreScaling
+from brevitas.function import abs_binary_sign_grad
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
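
Since every hunk above is formatting-only, the patched example should behave exactly as before. For readers who want to sanity-check that, here is a minimal usage sketch, not part of the patch: it assumes FloatESPCN is importable from the module path shown in the diff, uses the constructor defaults (upscale_factor=3, num_channels=3) visible in the __init__ hunk, and picks an arbitrary input size for illustration.

    import torch

    from brevitas_examples.super_resolution.models.espcn import FloatESPCN

    # Constructor defaults taken from the __init__ hunk above.
    model = FloatESPCN(upscale_factor=3, num_channels=3)
    model.eval()

    # NCHW input; 1x3x85x85 is an arbitrary size chosen for illustration.
    x = torch.randn(1, 3, 85, 85)
    with torch.no_grad():
        y = model(x)

    # Assuming the convolutions are size-preserving (k3/p1, k5/p2 as in the
    # diff), pixel_shuffle upsamples spatial dims by upscale_factor, so the
    # expected output shape is (1, 3, 255, 255).
    print(y.shape)

A quick shape check like this is a cheap way to confirm that reflowing the nn.Conv2d argument lists did not change the model graph.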