Skip to content

Commit

Permalink
Add a new transformation pass for approximating ConvTranspose with Upsampling + PointwiseConv
Browse files Browse the repository at this point in the history
  • Loading branch information
ptoupas committed Nov 14, 2023
1 parent 5dbeddf commit fe84e15
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 3 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ python threshold_relu_example.py
* `camvid`: `unet`
* `cityscapes`: `unet`

## Quantization Results
## Quantization Results
@ commit ec09e56
```
bash scripts/run_quantization.sh
Expand All @@ -44,8 +44,9 @@ bash scripts/run_quantization.sh
| Model | Source | Float32 | Fixed16 | Fixed8 | BFP8 (Layer) | BFP8 (Channel) |
|-------|-------------------------------------------------|---------|---------|--------|--------------|----------------|
| unet | [nncf](https://github.com/openvinotoolkit/nncf) | 71.95 | 71.95 | 61.02 | 71.60 | 71.85 |
| unet-approx | [nncf](https://github.com/openvinotoolkit/nncf) | 71.67 | - | - | - | - |

### cityscapes (val, mIOU)
### cityscapes (val, mIOU)
| Model | Source | Float32 | Fixed16 | Fixed8 | BFP8 (Layer) | BFP8 (Channel) |
|-------|----------------------------------------------------------------|---------|---------|--------|--------------|----------------|
| unet | [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) | 69.10 | 69.10 | 1.98 | 61.74 | 68.43 |
Expand All @@ -54,4 +55,4 @@ bash scripts/run_quantization.sh
* Optimizer: https://github.com/AlexMontgomerie/fpgaconvnet-optimiser; https://github.com/AlexMontgomerie/samo
* Model: https://github.com/AlexMontgomerie/fpgaconvnet-model
* HLS: https://github.com/AlexMontgomerie/fpgaconvnet-hls
* Tutorial: https://github.com/AlexMontgomerie/fpgaconvnet-tutorial
* Tutorial: https://github.com/AlexMontgomerie/fpgaconvnet-tutorial
73 changes: 73 additions & 0 deletions conv_transp_approx/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import torch
from torch import nn
import math

# Transposed-convolution layer types that are candidates for replacement by
# the upsampling + pointwise-convolution approximation.
CONV_TRANSP_MODULES = (
    nn.ConvTranspose1d,
    nn.ConvTranspose2d,
    nn.ConvTranspose3d,
)

class ConvTranspApproxLayer(nn.Module):
    """Approximate a transposed convolution with upsampling + a 1x1 convolution.

    The layer is built from an existing ``nn.ConvTranspose{1,2,3}d`` module.
    Its forward pass resizes the input with ``nn.Upsample`` and then applies a
    pointwise (kernel size 1) convolution whose weights are derived from the
    kernel of the original module. The bias, when present, is copied over
    unchanged.

    Args:
        parent_module: The transposed convolution to approximate. Subclasses
            of the ``nn.ConvTranspose{1,2,3}d`` types are accepted as well.
        upsampling_mode: Interpolation mode forwarded to ``nn.Upsample``.
        kernel_approx_strategy: How the original kernel is collapsed into a
            single weight per channel pair. Only ``"average"`` (mean over all
            kernel positions) is implemented.

    Raises:
        NotImplementedError: If ``parent_module`` is not a supported
            transposed convolution, if it uses ``groups != 1``, or if
            ``kernel_approx_strategy`` is not ``"average"``.
    """

    # Pointwise convolution class to build for each supported parent type.
    _POINTWISE_CONV_TYPES = (
        (nn.ConvTranspose1d, nn.Conv1d),
        (nn.ConvTranspose2d, nn.Conv2d),
        (nn.ConvTranspose3d, nn.Conv3d),
    )

    def __init__(self, parent_module, upsampling_mode, kernel_approx_strategy):
        super().__init__()
        self.device = parent_module.weight.device
        self.has_bias = isinstance(parent_module.bias, nn.Parameter)

        self.upsampling_mode = upsampling_mode
        self.kernel_approx_strategy = kernel_approx_strategy

        conv_type = self._pointwise_conv_type(parent_module)
        if parent_module.groups != 1:
            # The pointwise convolution below is ungrouped. With groups != 1
            # the transposed parent weight does not have the matching
            # (out_channels, in_channels, 1, ...) shape, so the copy would
            # either fail or silently broadcast wrong values.
            raise NotImplementedError(
                "ConvTranspose approximation only supports groups=1, "
                f"got groups={parent_module.groups}")
        if self.kernel_approx_strategy != "average":
            raise NotImplementedError(
                f"Unsupported kernel approximation strategy: {self.kernel_approx_strategy!r}")

        # TODO: The scale factor is estimated from an arbitrary reference
        # input size, using only the first spatial dimension. It assumes that
        # stride, padding, kernel_size and output_padding are the same for all
        # dimensions, and the kernel term does not take dilation into account.
        # If this is not the case, the computation below needs to change.
        rand_spatial_dim = 128
        approx_out_dim = (
            (rand_spatial_dim - 1) * parent_module.stride[0]
            - 2 * parent_module.padding[0]
            + parent_module.kernel_size[0]
            + parent_module.output_padding[0]
        )
        self.scale_factor = math.ceil(approx_out_dim / rand_spatial_dim)

        self.upsample = nn.Upsample(
            scale_factor=self.scale_factor, mode=self.upsampling_mode).to(self.device)
        # Match the parent's device and dtype so the replacement accepts the
        # same inputs as the layer it stands in for.
        self.pointwise_conv = conv_type(
            in_channels=parent_module.in_channels,
            out_channels=parent_module.out_channels,
            kernel_size=1,
            bias=self.has_bias,
        ).to(device=self.device, dtype=parent_module.weight.dtype)

        # ConvTranspose weights are laid out (in, out, *kernel) while Conv
        # weights are (out, in, *kernel): swap the channel axes, then average
        # over every kernel axis to obtain a 1x1 kernel.
        weights = parent_module.weight.data.transpose(0, 1)
        kernel_dims = tuple(range(2, weights.dim()))
        weights = torch.mean(weights, dim=kernel_dims, keepdim=True)

        self.pointwise_conv.weight.data.copy_(weights)
        if self.has_bias:
            self.pointwise_conv.bias.data.copy_(parent_module.bias.data)

    @classmethod
    def _pointwise_conv_type(cls, parent_module):
        """Return the ``nn.Conv{1,2,3}d`` class matching ``parent_module``."""
        for transp_type, conv_type in cls._POINTWISE_CONV_TYPES:
            if isinstance(parent_module, transp_type):
                return conv_type
        raise NotImplementedError(
            f"Unsupported module type: {type(parent_module).__name__}")

    def forward(self, x):
        """Upsample ``x`` and apply the pointwise convolution."""
        x = self.upsample(x)
        x = self.pointwise_conv(x)

        return x


def apply_conv_transp_approx(model_wrapper, upsampling_mode, kernel_approx_strategy):
    """Replace the wrapped model's transposed convolutions with approximations.

    Every module yielded by ``model_wrapper.named_modules()`` that is an
    instance of one of ``CONV_TRANSP_MODULES`` is paired with a freshly built
    ``ConvTranspApproxLayer``; the resulting mapping is then handed to
    ``model_wrapper.replace_modules``.

    Args:
        model_wrapper: Object exposing ``named_modules()`` and
            ``replace_modules(mapping)``.
        upsampling_mode: Passed through to ``ConvTranspApproxLayer``.
        kernel_approx_strategy: Passed through to ``ConvTranspApproxLayer``.
    """
    replacements = {
        layer: ConvTranspApproxLayer(
            parent_module=layer,
            upsampling_mode=upsampling_mode,
            kernel_approx_strategy=kernel_approx_strategy,
        )
        for _, layer in model_wrapper.named_modules()
        if isinstance(layer, CONV_TRANSP_MODULES)
    }
    model_wrapper.replace_modules(replacements)
62 changes: 62 additions & 0 deletions transpose_conv_approximation_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import argparse
import copy
import os
import pathlib
import random
import torch

from models import initialize_wrapper
from quantization.utils import QuantMode, quantize_model
from conv_transp_approx.utils import apply_conv_transp_approx

def main():
    """Run float32 inference before and after the ConvTranspose approximation.

    Parses the command-line arguments, builds the model wrapper, calls
    ``inference("test")`` and exports ONNX files for the unmodified model,
    then applies ``apply_conv_transp_approx`` (bilinear upsampling, averaged
    kernels), calls ``inference("test")`` again and exports the approximated
    model to ``<output_path>/fp32_approx``.
    """
    parser = argparse.ArgumentParser(description='Transpose Convolution Approximation Example')
    parser.add_argument('--dataset_name', default="camvid", type=str,
                        help='dataset name')
    parser.add_argument('--dataset_path', metavar='DIR', default="~/dataset/ILSVRC2012_img",
                        help='path to dataset')
    parser.add_argument('--model_name', metavar='ARCH', default='unet',
                        help='model architecture')

    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers')
    parser.add_argument('-b', '--batch-size', default=16, type=int, metavar='N',
                        help='mini-batch size')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use.')

    parser.add_argument('--output_path', default=None, type=str,
                        help='output path')

    args = parser.parse_args()
    if args.output_path is None:
        args.output_path = os.path.join(os.getcwd(), "output")
    pathlib.Path(args.output_path).mkdir(parents=True, exist_ok=True)
    print(args)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)

    # Fixed seeds so repeated runs are comparable.
    random.seed(0)
    torch.manual_seed(0)

    model_wrapper = initialize_wrapper(args.dataset_name, args.model_name,
                                       os.path.expanduser(args.dataset_path), args.batch_size, args.workers)

    # TEST 1
    print("FLOAT32 Inference")
    model_wrapper.inference("test")
    model_wrapper.generate_onnx_files(os.path.join(args.output_path, "fp32"))

    # TEST 2
    apply_conv_transp_approx(model_wrapper=model_wrapper, upsampling_mode="bilinear", kernel_approx_strategy="average")
    print("FLOAT32 Inference Conv Transpose Approximation")
    model_wrapper.inference("test")
    # FIXME: if we use generate_onnx_files here the onnx saved model does not contain the changes made by apply_conv_transp_approx
    # onnx_exporter receives a file path rather than a directory, so make sure
    # the parent directory exists before exporting (whether the exporter
    # creates it itself is not visible from this script).
    approx_dir = os.path.join(args.output_path, "fp32_approx")
    pathlib.Path(approx_dir).mkdir(parents=True, exist_ok=True)
    model_wrapper.onnx_exporter(os.path.join(approx_dir, f"{args.model_name}_f32.onnx"))


if __name__ == '__main__':
    main()

0 comments on commit fe84e15

Please sign in to comment.