-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a new transformation pass for approximating ConvTranspose with Upsampling + PointwiseConv
- Loading branch information
Showing
3 changed files
with
139 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import torch | ||
from torch import nn | ||
import math | ||
|
||
# Transposed-convolution module types that the approximation pass looks for
# (used as the second argument of an isinstance() check).
CONV_TRANSP_MODULES = (
    nn.ConvTranspose1d,
    nn.ConvTranspose2d,
    nn.ConvTranspose3d,
)
|
||
class ConvTranspApproxLayer(nn.Module):
    """Approximate a transposed convolution with upsampling + a 1x1 convolution.

    The layer is built from an existing ``nn.ConvTranspose{1,2,3}d`` module
    (``parent_module``). Its forward pass first upsamples the input with
    ``nn.Upsample`` and then applies a pointwise (kernel size 1) convolution
    whose weights are derived from the parent's kernel.

    Args:
        parent_module: The transposed convolution to approximate. Subclasses
            of the supported ``nn.ConvTranspose*d`` types are accepted.
        upsampling_mode: Forwarded unchanged as ``mode`` to ``nn.Upsample``.
            It must be valid for the dimensionality of the parent module.
        kernel_approx_strategy: How the parent's kernel is collapsed into a
            1x1 kernel. Only ``"average"`` (mean over the kernel's spatial
            dimensions) is implemented.

    Raises:
        TypeError: If ``parent_module`` is not a supported transposed conv.
        NotImplementedError: If ``kernel_approx_strategy`` is not ``"average"``.
    """

    # Reference input size used to turn the transposed-conv output-size
    # formula into a single integer scale factor. The result is rounded up,
    # so for configurations where out_size is not an exact multiple of
    # in_size the upsampled size may differ from the parent's output size.
    _REFERENCE_SPATIAL_DIM = 128

    # (transposed conv type, matching pointwise conv type), checked in order.
    _POINTWISE_CONV_TYPES = (
        (nn.ConvTranspose1d, nn.Conv1d),
        (nn.ConvTranspose2d, nn.Conv2d),
        (nn.ConvTranspose3d, nn.Conv3d),
    )

    def __init__(self, parent_module, upsampling_mode, kernel_approx_strategy):
        super().__init__()

        # Validate up front so failures name the real problem instead of
        # surfacing later as an undefined-variable / missing-attribute error.
        conv_cls = self._pointwise_conv_type(parent_module)
        if kernel_approx_strategy != "average":
            raise NotImplementedError(
                f"kernel_approx_strategy {kernel_approx_strategy!r} is not "
                "implemented; only 'average' is supported")

        self.device = parent_module.weight.device
        self.has_bias = isinstance(parent_module.bias, nn.Parameter)

        self.upsampling_mode = upsampling_mode
        self.kernel_approx_strategy = kernel_approx_strategy

        # An int when every spatial dimension shares the same factor,
        # otherwise a tuple with one factor per spatial dimension.
        self.scale_factor = self._estimate_scale_factor(parent_module)

        # nn.Upsample holds no parameters or buffers, so no device move needed.
        self.upsample = nn.Upsample(
            scale_factor=self.scale_factor, mode=self.upsampling_mode)

        # Mirror the parent's grouping, device and dtype so the replacement
        # accepts the same activations the parent did.
        self.pointwise_conv = conv_cls(
            in_channels=parent_module.in_channels,
            out_channels=parent_module.out_channels,
            kernel_size=1,
            groups=parent_module.groups,
            bias=self.has_bias,
        ).to(device=self.device, dtype=parent_module.weight.dtype)

        with torch.no_grad():
            self.pointwise_conv.weight.copy_(self._average_kernel(parent_module))
            if self.has_bias:
                self.pointwise_conv.bias.copy_(parent_module.bias.detach())

    @classmethod
    def _pointwise_conv_type(cls, parent_module):
        """Return the nn.Conv*d class matching the parent's dimensionality."""
        for transp_cls, conv_cls in cls._POINTWISE_CONV_TYPES:
            if isinstance(parent_module, transp_cls):
                return conv_cls
        raise TypeError(
            "parent_module must be an nn.ConvTranspose1d/2d/3d instance, got "
            f"{type(parent_module).__name__}")

    @classmethod
    def _estimate_scale_factor(cls, parent_module):
        """Estimate the integer upsampling factor for each spatial dimension.

        Uses the transposed-convolution output-size formula
        ``(in - 1) * stride - 2 * padding + dilation * (kernel - 1)
        + output_padding + 1`` evaluated at the reference input size, divided
        by that input size and rounded up.
        """
        ref = cls._REFERENCE_SPATIAL_DIM
        factors = tuple(
            math.ceil(
                ((ref - 1) * stride - 2 * padding
                 + dilation * (kernel - 1) + output_padding + 1) / ref)
            for stride, padding, dilation, kernel, output_padding in zip(
                parent_module.stride,
                parent_module.padding,
                parent_module.dilation,
                parent_module.kernel_size,
                parent_module.output_padding,
            )
        )
        # Keep a plain int in the uniform case (the common one).
        return factors[0] if len(set(factors)) == 1 else factors

    @staticmethod
    def _average_kernel(parent_module):
        """Collapse the parent's kernel into pointwise-conv weight layout.

        A transposed conv stores weights as (in, out / groups, *kernel) while
        a regular conv expects (out, in / groups, *kernel). The channel axes
        are therefore swapped within each group before averaging over the
        kernel's spatial dimensions. With ``groups == 1`` this is the same as
        swapping the first two axes.
        """
        weight = parent_module.weight.detach()
        groups = parent_module.groups
        in_per_group = weight.shape[0] // groups
        out_per_group = weight.shape[1]
        kernel_shape = weight.shape[2:]

        weight = weight.reshape(
            groups, in_per_group, out_per_group, *kernel_shape)
        weight = weight.transpose(1, 2)
        weight = weight.reshape(
            groups * out_per_group, in_per_group, *kernel_shape)

        spatial_dims = tuple(range(2, weight.dim()))
        return weight.mean(dim=spatial_dims, keepdim=True)

    def forward(self, x):
        """Upsample ``x`` and apply the pointwise convolution."""
        x = self.upsample(x)
        x = self.pointwise_conv(x)

        return x
|
||
|
||
def apply_conv_transp_approx(model_wrapper, upsampling_mode, kernel_approx_strategy):
    """Swap every transposed convolution in the wrapped model for its approximation.

    Walks ``model_wrapper.named_modules()``, builds a ``ConvTranspApproxLayer``
    for each module that is an instance of ``CONV_TRANSP_MODULES`` and hands
    the resulting ``{original module: replacement}`` mapping to
    ``model_wrapper.replace_modules``.

    Args:
        model_wrapper: Object exposing ``named_modules()`` and
            ``replace_modules(mapping)``.
        upsampling_mode: Passed through to each ``ConvTranspApproxLayer``.
        kernel_approx_strategy: Passed through to each ``ConvTranspApproxLayer``.
    """
    transposed_convs = [
        layer
        for _, layer in model_wrapper.named_modules()
        if isinstance(layer, CONV_TRANSP_MODULES)
    ]
    replacements = {
        layer: ConvTranspApproxLayer(
            parent_module=layer,
            upsampling_mode=upsampling_mode,
            kernel_approx_strategy=kernel_approx_strategy,
        )
        for layer in transposed_convs
    }
    model_wrapper.replace_modules(replacements)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import argparse | ||
import copy | ||
import os | ||
import pathlib | ||
import random | ||
import torch | ||
|
||
from models import initialize_wrapper | ||
from quantization.utils import QuantMode, quantize_model | ||
from conv_transp_approx.utils import apply_conv_transp_approx | ||
|
||
def main():
    """Run the transposed-convolution approximation example end to end.

    Parses the command-line options, builds the model wrapper, runs a
    baseline FP32 inference and ONNX export, then applies
    ``apply_conv_transp_approx`` and repeats inference and export on the
    modified model.
    """
    parser = argparse.ArgumentParser(description='Transpose Convolution Approximation Example')
    parser.add_argument('--dataset_name', default="camvid", type=str,
                        help='dataset name')
    parser.add_argument('--dataset_path', metavar='DIR', default="~/dataset/ILSVRC2012_img",
                        help='path to dataset')
    parser.add_argument('--model_name', metavar='ARCH', default='unet',
                        help='model architecture')

    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers')
    parser.add_argument('-b', '--batch-size', default=16, type=int, metavar='N',
                        help='mini-batch size')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use.')

    parser.add_argument('--output_path', default=None, type=str,
                        help='output path')

    args = parser.parse_args()
    # Default to "<cwd>/output" when no output directory was given.
    if args.output_path is None:
        args.output_path = os.path.join(os.getcwd(), "output")
    pathlib.Path(args.output_path).mkdir(parents=True, exist_ok=True)
    print(args)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)

    # Fixed seeds for Python's and torch's RNGs.
    random.seed(0)
    torch.manual_seed(0)

    model_wrapper = initialize_wrapper(args.dataset_name, args.model_name,
                                       os.path.expanduser(args.dataset_path), args.batch_size, args.workers)

    # Baseline: unmodified FP32 model.
    print("FLOAT32 Inference")
    model_wrapper.inference("test")
    model_wrapper.generate_onnx_files(os.path.join(args.output_path, "fp32"))

    # Approximated model: transposed convolutions replaced by
    # upsampling + pointwise convolution.
    apply_conv_transp_approx(model_wrapper=model_wrapper, upsampling_mode="bilinear", kernel_approx_strategy="average")
    print("FLOAT32 Inference Conv Transpose Approximation")
    model_wrapper.inference("test")
    # FIXME: if we use generate_onnx_files here the onnx saved model does not contain the changes made by apply_conv_transp_approx
    approx_onnx_path = os.path.join(args.output_path, "fp32_approx", f"{args.model_name}_f32.onnx")
    # NOTE(review): onnx_exporter receives a file path; it is not visible here
    # whether it creates missing parent directories, so make sure it exists.
    pathlib.Path(approx_onnx_path).parent.mkdir(parents=True, exist_ok=True)
    model_wrapper.onnx_exporter(approx_onnx_path)


if __name__ == '__main__':
    main()
|