Skip to content

Commit

Permalink
update bilinear upsampling results
Browse files Browse the repository at this point in the history
  • Loading branch information
Yu-Zhewen committed Nov 14, 2023
1 parent fe84e15 commit fd8942e
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 77 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ bash scripts/run_quantization.sh
| yolov8n | [ultralytics](https://github.com/ultralytics/ultralytics) | 37.1 | 37.1 | 0.0 | 0.0 | 35.1 |

### camvid (val, mIOU)
| Model | Source | Float32 | Fixed16 | Fixed8 | BFP8 (Layer) | BFP8 (Channel) |
|-------|-------------------------------------------------|---------|---------|--------|--------------|----------------|
| unet | [nncf](https://github.com/openvinotoolkit/nncf) | 71.95 | 71.95 | 61.02 | 71.60 | 71.85 |
| unet-approx | [nncf](https://github.com/openvinotoolkit/nncf) | 71.67 | - | - | - | - |
| Model | Source | Float32 | Fixed16 | Fixed8 | BFP8 (Layer) | BFP8 (Channel) |
|---------------|-------------------------------------------------|---------|---------|--------|--------------|----------------|
| unet | [nncf](https://github.com/openvinotoolkit/nncf) | 71.95 | 71.95 | 61.02 | 71.60 | 71.85 |
| unet-bilinear | [nncf](https://github.com/openvinotoolkit/nncf) | 71.67 | 71.67 | 60.62 | 71.40 | 71.75 |

### cityscapes (val, mIOU)
| Model | Source | Float32 | Fixed16 | Fixed8 | BFP8 (Layer) | BFP8 (Channel) |
Expand Down
5 changes: 4 additions & 1 deletion models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,8 @@ def onnx_exporter(self):
def forward(self, x):
return self.model(x)

from models.utils import replace_modules
def replace_modules(self, replace_dict):
from models.utils import replace_modules
replace_modules(self.model, replace_dict)

from models.utils import generate_onnx_files
7 changes: 6 additions & 1 deletion models/segmentation/camvid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from collections import OrderedDict
from models.base import TorchModelWrapper
from models.segmentation.utils import apply_conv_transp_approx
from PIL import Image
from torch.utils import data

Expand All @@ -18,7 +19,7 @@ def __init__(self, model_name, input_size=(1, 3, 368, 480), num_classes=12):
self.num_classes = num_classes
super().__init__(model_name)

def load_model(self, eval=True):
def load_model(self, eval=True, approx_transpose_conv=True):
assert self.model_name == 'unet'

self.model = UNet(input_size_hw=self.input_size[2:], in_channels=self.input_size[1], n_classes=self.num_classes)
Expand All @@ -27,6 +28,10 @@ def load_model(self, eval=True):
# remove 'module.' prefix
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
self.model.load_state_dict(state_dict)

if approx_transpose_conv:
apply_conv_transp_approx(self.model)

if torch.cuda.is_available():
self.model = self.model.cuda()

Expand Down
14 changes: 7 additions & 7 deletions conv_transp_approx/utils.py → models/segmentation/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import math
import torch

from models.utils import replace_modules
from torch import nn
import math

CONV_TRANSP_MODULES = (
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
CONV_TRANSP_MODULES = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)

class ConvTranspApproxLayer(nn.Module):
def __init__(self, parent_module, upsampling_mode, kernel_approx_strategy):
Expand Down Expand Up @@ -62,12 +63,11 @@ def forward(self, x):

return x


def apply_conv_transp_approx(model_wrapper, upsampling_mode, kernel_approx_strategy):
def apply_conv_transp_approx(model, upsampling_mode="bilinear", kernel_approx_strategy="average"):
replace_dict = {}
for name, module in model_wrapper.named_modules():
for name, module in model.named_modules():
if isinstance(module, CONV_TRANSP_MODULES):
new_module = ConvTranspApproxLayer(
parent_module=module, upsampling_mode=upsampling_mode, kernel_approx_strategy=kernel_approx_strategy)
replace_dict[module] = new_module
model_wrapper.replace_modules(replace_dict)
replace_modules(model, replace_dict)
4 changes: 2 additions & 2 deletions models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import torch.nn as nn

def replace_modules(self, replace_dict):
for name, module in self.model.named_modules():
def replace_modules(model, replace_dict):
for name, module in model.named_modules():
for subname, submodule in module.named_children():
if submodule in replace_dict.keys():
new_submodule = replace_dict[submodule]
Expand Down
62 changes: 0 additions & 62 deletions transpose_conv_approximation_example.py

This file was deleted.

0 comments on commit fd8942e

Please sign in to comment.