Skip to content

Commit

Permalink
update encoding results
Browse files Browse the repository at this point in the history
  • Loading branch information
Yu-Zhewen committed Dec 20, 2023
1 parent 216d7f9 commit b54e4fd
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 17 deletions.
13 changes: 11 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ python threshold_relu_example.py
* WS - Weight Sparsity (applying global pruning threshold)
* Post-training, without fine-tuning

### imagenet

| Model | Experiment | Accuracy | Sparsity |
|----------|----------------|----------|----------|
| resnet18 | Q+AS | 69.74 | 50.75 |
Expand All @@ -79,6 +77,17 @@ python threshold_relu_example.py
| resnet18 | Q+AS+WS(0.015) | 58.38 | 65.91 |
| resnet18 | Q+AS+WS(0.020) | 27.91 | 69.63 |

## Encoding Results
* BFP8 (Channel) Quantization
* RLE-8, run-length encoding, use 8 bits for encoding (max length 2^8)
* Compression Ratio, average over all weights and activations

| Dataset | Model | Experiment | Compression Ratio |
|------------|----------------------|------------|-------------------|
| coco | yolov8n ([onnx](https://drive.google.com/file/d/10-lNBid4VRzWBrE6GuT3I3L3H2BtWT1P/view?usp=sharing)) | RLE-8 | 1.753 |
| camvid | unet-bilinear ([onnx](https://drive.google.com/file/d/1C_Q58_NKMVfpbqg3ZbQ1IzyMSgoopex7/view?usp=sharing)) | RLE-8 | 1.175 |
| cityscapes | unet (onnx) | RLE-8 | |

## Links to other repos
* Optimizer: https://github.com/AlexMontgomerie/fpgaconvnet-optimiser; https://github.com/AlexMontgomerie/samo
* Model: https://github.com/AlexMontgomerie/fpgaconvnet-model
Expand Down
12 changes: 12 additions & 0 deletions encoding/rle.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import torch

import numpy as np
import torch.nn as nn

from models.classification.utils import AverageMeter
Expand All @@ -13,6 +15,10 @@ def get_compression_ratio(x, encoded_x, x_bits, l_bits):
return (len(encoded_x) * (x_bits + l_bits)) / (len(x.flatten()) * x_bits)

def encode(x, word_length, scaling_factor, zero_point, l_bits, transpose=False):
if torch.cuda.is_available():
x = x.cuda()
scaling_factor = scaling_factor.cuda()
zero_point = zero_point.cuda()
# convert to quantized int representation
if transpose:
x = x.transpose(0, 1)
Expand Down Expand Up @@ -89,3 +95,9 @@ def encode_model(model_wrapper, l_bits):
assert False, "unexpected module name"

model_wrapper.sideband_info["encoding"] = encode_info

compression_ratio = []
for v in encode_info.values():
compression_ratio += list(v.values())
compression_ratio = np.mean(compression_ratio)
return compression_ratio
6 changes: 4 additions & 2 deletions encoding_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def main():

args = parser.parse_args()
if args.output_path == None:
args.output_path = os.getcwd() + "/output"
args.output_path = os.path.join(os.getcwd(),
f"output/{args.dataset_name}/{args.model_name}")
pathlib.Path(args.output_path).mkdir(parents=True, exist_ok=True)
print(args)

Expand All @@ -48,7 +49,8 @@ def main():
'weight_width': 8, 'data_width': 8, 'mode': QuantMode.CHANNEL_BFP})

# encoding
encode_model(model_wrapper, 8)
ratio = encode_model(model_wrapper, 8)
print("Encoding ratio: ", ratio)
model_wrapper.generate_onnx_files(os.path.join(args.output_path, "encode"))


Expand Down
23 changes: 19 additions & 4 deletions models/detection/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import onnx_graphsurgeon as gs
import torch.nn as nn

from models.base import TorchModelWrapper
# note: do NOT move ultralytic import to the top, otherwise the edit in settings will not take effect
Expand All @@ -16,15 +17,19 @@ def load_model(self, eval=True):
from ultralytics import YOLO
self.yolo = YOLO(self.model_name)
self.model = self.yolo.model
#if torch.cuda.is_available():
# self.model = self.model.cuda()
self.model_fixer()

# utlralytics conv bn fusion is currently not working for compressed model
# disbale it for now
# utlralytics conv bn fusion not working after compression, disable it
def _fuse(verbose=True):
return self.model
self.model.fuse = _fuse

def model_fixer(self):
from ultralytics.nn.modules import Conv
for name, module in self.named_modules():
if isinstance(module, Conv) and isinstance(module.act, nn.SiLU):
module.act = nn.Hardswish(inplace=True)

def load_data(self, batch_size, workers):
from ultralytics import settings

Expand Down Expand Up @@ -53,6 +58,16 @@ def onnx_exporter(self, onnx_path):
os.rename(path, onnx_path)
self.remove_detection_head_v8(onnx_path)

# rename sideband info
for method, info in self.sideband_info.items():
new_info = {}
for k, v in info.items():
if k.startswith("yolo.model."):
new_info[k.replace("yolo.model.", "")] = v
else:
new_info[k] = v
self.sideband_info[method] = new_info

def remove_detection_head_v8(self, onnx_path):
graph = onnx.load(onnx_path)
graph = gs.import_onnx(graph)
Expand Down
21 changes: 17 additions & 4 deletions models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,19 +115,32 @@ def _annotate_encoding(onnx_model, sideband_info):
continue
inputs = node.input
outputs = node.output
if node.op_type == 'Resize':
if node.op_type in ['Resize', 'Split']:
inputs = [node.input[0]]
input_compression_ratio = []
for input_name in inputs:
p_node = find_producer(onnx_model.graph, input_name)
if p_node == None:
input_compression_ratio.append(1.0) # todo: fix missing info
continue
p_name = onnx_to_torch_name_cast(p_node.name, p_node.op_type)
input_compression_ratio.append(info[p_name]["output_compression_ratio"])
if p_name in info.keys():
input_compression_ratio.append(info[p_name]["output_compression_ratio"])
else:
input_compression_ratio.append(1.0) # todo: fix missing info
set_nodeattr(node, "input_compression_ratio", input_compression_ratio)
output_compression_ratio = []
for output_name in outputs:
c_node = find_consumers(onnx_model.graph, output_name)[0]
c_node = find_consumers(onnx_model.graph, output_name)
if len(c_node) == 0:
output_compression_ratio.append(1.0) # todo: fix missing info
continue
c_node = c_node[0]
c_name = onnx_to_torch_name_cast(c_node.name, c_node.op_type)
output_compression_ratio.append(info[c_name]["input_compression_ratio"])
if c_name in info.keys():
output_compression_ratio.append(info[c_name]["input_compression_ratio"])
else:
output_compression_ratio.append(1.0) # todo: fix missing info
set_nodeattr(node, "output_compression_ratio", output_compression_ratio)


Expand Down
10 changes: 8 additions & 2 deletions quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,14 @@ class QuantMode(Enum):
LAYER_BFP = 2
CHANNEL_BFP = 3

ACTIVA_QUANT_MODULES = (nn.Conv2d, nn.Conv3d, nn.Linear, nn.ConvTranspose2d, nn.ConvTranspose3d, nn.ReLU, nn.ReLU6, nn.MaxPool2d, nn.MaxPool3d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d, nn.AvgPool2d, nn.AvgPool3d)
WEIGHT_QUANT_MODULES = (nn.Conv2d, nn.Conv3d, nn.Linear, nn.ConvTranspose2d, nn.ConvTranspose3d)
ACTIVA_QUANT_MODULES = (nn.Conv2d, nn.Conv3d, nn.Linear,
nn.ConvTranspose2d, nn.ConvTranspose3d,
nn.ReLU, nn.ReLU6, nn.LeakyReLU, nn.Hardswish,
nn.MaxPool2d, nn.MaxPool3d,
nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d,
nn.AvgPool2d, nn.AvgPool3d)
WEIGHT_QUANT_MODULES = (nn.Conv2d, nn.Conv3d, nn.Linear,
nn.ConvTranspose2d, nn.ConvTranspose3d)

def linear_quantize(x, scaling_factor, zero_point):
if len(x.shape) == 5:
Expand Down
3 changes: 2 additions & 1 deletion quantization_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def main():

args = parser.parse_args()
if args.output_path == None:
args.output_path = os.getcwd() + "/output"
args.output_path = os.path.join(os.getcwd(),
f"output/{args.dataset_name}/{args.model_name}")
pathlib.Path(args.output_path).mkdir(parents=True, exist_ok=True)
print(args)

Expand Down
3 changes: 2 additions & 1 deletion sparsity_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def main():

args = parser.parse_args()
if args.output_path == None:
args.output_path = os.getcwd() + "/output"
args.output_path = os.path.join(os.getcwd(),
f"output/{args.dataset_name}/{args.model_name}")
pathlib.Path(args.output_path).mkdir(parents=True, exist_ok=True)
print(args)

Expand Down
3 changes: 2 additions & 1 deletion threshold_relu_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def main():

args = parser.parse_args()
if args.output_path == None:
args.output_path = os.getcwd() + "/output"
args.output_path = os.path.join(os.getcwd(),
f"output/{args.dataset_name}/{args.model_name}")
pathlib.Path(args.output_path).mkdir(parents=True, exist_ok=True)
print(args)

Expand Down

0 comments on commit b54e4fd

Please sign in to comment.