Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…into main
  • Loading branch information
Yu-Zhewen committed Jan 7, 2024
2 parents a2c0218 + 4b91b19 commit b9a2fe0
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 34 deletions.
26 changes: 18 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ python threshold_relu_example.py
* `cityscapes`: `unet`
* `llgmri`: `unet`
* `ucf101`: `x3d_s`, `x3d_m`
* `brats20`: `unet3d`
* `brats2020`: `unet3d`

## Quantization Results

Expand Down Expand Up @@ -62,6 +62,10 @@ python threshold_relu_example.py
| x3d_s | [mmaction2](https://github.com/open-mmlab/mmaction2) | 93.68 | 93.57 | 1.13 | 90.21 | 93.57 |
| x3d_m | [mmaction2](https://github.com/open-mmlab/mmaction2) | 96.40 | 96.40 | 0.81 | 95.24 | 96.29 |

### brats2020 (val, Dice coefficient)
| Model | Source | Float32 | Fixed16 | Fixed8 | BFP8 (Layer) | BFP8 (Channel) |
|-------|----------------------------------------------------------------|---------|---------|--------|--------------|----------------|
| unet3d | [BraTS20_3dUnet_3dAutoEncoder](https://www.kaggle.com/code/polomarco/brats20-3dunet-3dautoencoder) | 85.34 | 85.23 | 1.15 | 85.14 | 85.34 |

## Sparsity Results
* Q - Fixed16 Quantization
Expand All @@ -82,14 +86,20 @@ python threshold_relu_example.py
* RLE-8, run-length encoding, use 8 bits for encoding (max length 2^8)
* Compression Ratio, average over all weights and activations

| Dataset | Model | Experiment | Compression Ratio |
| Dataset | Model | Experiment | Avg Compression Ratio |
|------------|----------------------|------------|-------------------|
| coco | yolov8n ([onnx](https://drive.google.com/file/d/10-lNBid4VRzWBrE6GuT3I3L3H2BtWT1P/view?usp=sharing)) | RLE-8 | 1.753 |
| camvid | unet-bilinear ([onnx](https://drive.google.com/file/d/1C_Q58_NKMVfpbqg3ZbQ1IzyMSgoopex7/view?usp=sharing)) | RLE-8 | 1.175 |
| cityscapes | unet (onnx) | RLE-8 | GPU TIMEOUT |
| coco | yolov8n ([onnx](https://drive.google.com/file/d/1ghj2Da4HdkHSC-ADSe-JwvQbtwhUT_vT/view?usp=sharing)) | Huffman | 0.821 |
| camvid | unet-bilinear ([onnx](https://drive.google.com/file/d/1X6Ps_qcbP7vJLgNCkHbsHtWY6aSnG8es/view?usp=sharing)) | Huffman | 0.684 |
| cityscapes | unet ([onnx](https://drive.google.com/file/d/1d2v6VJI8B9DZY020Nq_AWQR0e8F9LH6A/view?usp=sharing)) | Huffman | 0.692 |
| coco | yolov8n ([onnx](https://drive.google.com/file/d/10-lNBid4VRzWBrE6GuT3I3L3H2BtWT1P/view?usp=sharing)) | RLE-8 | 1.753 |
| camvid | unet-bilinear ([onnx](https://drive.google.com/file/d/1C_Q58_NKMVfpbqg3ZbQ1IzyMSgoopex7/view?usp=sharing)) | RLE-8 | 1.175 |
| cityscapes | unet (onnx) | RLE-8 | GPU TIMEOUT |
| ucf101 | x3d_s ([onnx](https://drive.google.com/file/d/1gY5HGMWacbTQ5cK8MWdQgQ1lQM5VWRFb/view?usp=sharing)) | RLE-8 | 1.737 |
| ucf101 | x3d_m ([onnx](https://drive.google.com/file/d/1WaLjJYE0l_AiIrZw559Ile3xQza_wnWJ/view?usp=sharing)) | RLE-8 | 1.721 |
| brats2020 | unet3d (onnx) | RLE-8 | TBA |
| coco | yolov8n ([onnx](https://drive.google.com/file/d/1ghj2Da4HdkHSC-ADSe-JwvQbtwhUT_vT/view?usp=sharing)) | Huffman | 0.821 |
| camvid | unet-bilinear ([onnx](https://drive.google.com/file/d/1X6Ps_qcbP7vJLgNCkHbsHtWY6aSnG8es/view?usp=sharing)) | Huffman | 0.684 |
| cityscapes | unet ([onnx](https://drive.google.com/file/d/1d2v6VJI8B9DZY020Nq_AWQR0e8F9LH6A/view?usp=sharing)) | Huffman | 0.692 |
| ucf101 | x3d_s ([onnx](https://drive.google.com/file/d/19c6jwuHZVcfZXPpXMaGmaK9AsRXPO5lJ/view?usp=sharing)) | Huffman | 0.835 |
| ucf101 | x3d_m ([onnx](https://drive.google.com/file/d/1RQr0lEuROwO14F0WtObBUmuz8Na3Vci2/view?usp=sharing)) | Huffman | 0.833 |
| brats2020 | unet3d (onnx) | Huffman | TBA |

## Links to other repos
* Optimizer: https://github.com/AlexMontgomerie/fpgaconvnet-optimiser; https://github.com/AlexMontgomerie/samo
Expand Down
18 changes: 10 additions & 8 deletions encoding/huffman.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,32 @@
import torch

from dahuffman import HuffmanCodec
from encoding.utils import convert_to_int, avg_compress_ratio
from quantization.utils import QuantMode, QuantAct, WEIGHT_QUANT_MODULES

from encoding.utils import (avg_compress_ratio, avg_compress_ratio_detailed,
convert_to_int)
from quantization.utils import WEIGHT_QUANT_MODULES, QuantAct, QuantMode


def get_huffman_encoding_ratio(count, x_width):
keys = [int(i) for i in range(-2**(x_width - 1), 2**(x_width - 1))]
hist = { keys[i]: int(count[i]) for i in range(len(keys)) }

codec = HuffmanCodec.from_frequencies(hist)
table = codec.get_code_table()

bits = [table[i][0] for i in keys]
avg_bits = count @ torch.tensor(bits, dtype=torch.float32, device=count.device) / torch.sum(count)
ratio = avg_bits.item() / x_width
return ratio

def log_hist_count(module, input, output):
quant_data = convert_to_int(input[0], module.word_length,
quant_data = convert_to_int(input[0], module.word_length,
module.scaling_factor, module.zero_point, (module.mode == QuantMode.CHANNEL_BFP))
count = torch.histc(quant_data, bins=2**module.word_length, min=-2**(module.word_length - 1), max=2**(module.word_length - 1)-1)
module.hist_count += count

def huffman_model(model_wrapper):
assert "quantization" in model_wrapper.sideband_info.keys(), "Only quantized models can be encoded"

weight_width = model_wrapper.sideband_info["quantization"]["weight_width"]
data_width = model_wrapper.sideband_info["quantization"]["data_width"]

Expand All @@ -40,7 +42,7 @@ def huffman_model(model_wrapper):
ratio = get_huffman_encoding_ratio(count, weight_width)
encode_info[name] = {"weight_compression_ratio": ratio}
elif isinstance(module, QuantAct):
assert name.endswith(".0") or name.endswith(".2")
assert name.endswith(".0") or name.endswith(".2")
module.hist_count = torch.zeros(2**data_width, dtype=torch.float32)
if torch.cuda.is_available():
module.hist_count = module.hist_count.cuda()
Expand All @@ -59,4 +61,4 @@ def huffman_model(model_wrapper):
assert False, "unexpected module name"

model_wrapper.sideband_info["encoding"] = encode_info
return avg_compress_ratio(encode_info)
return avg_compress_ratio(encode_info), avg_compress_ratio_detailed(encode_info)
19 changes: 10 additions & 9 deletions encoding/rle.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import torch

import torch.nn as nn

from encoding.utils import convert_to_int, avg_compress_ratio
from encoding.utils import (avg_compress_ratio, avg_compress_ratio_detailed,
convert_to_int)
from models.classification.utils import AverageMeter
from quantization.utils import QuantMode, QuantAct, WEIGHT_QUANT_MODULES
from quantization.utils import WEIGHT_QUANT_MODULES, QuantAct, QuantMode


def rle_compression_ratio(x, encoded_x, x_bits, l_bits):
assert len(x.shape) == 1, "x must be a 1D tensor"
Expand All @@ -31,20 +32,20 @@ def rle_encode(x_flatten, l_bits):
additional_elements.append([value, remain])
additional_elements = torch.tensor(additional_elements, device=x_flatten.device)
combined = torch.cat((combined[:index], additional_elements, combined[index + 1:]), dim=0)

return combined

def log_encoding(module, input, output):
batch_size = input[0].shape[0]
quant_data = convert_to_int(input[0], module.word_length,
quant_data = convert_to_int(input[0], module.word_length,
module.scaling_factor, module.zero_point, (module.mode == QuantMode.CHANNEL_BFP))
encoded_data = rle_encode(quant_data, module.l_bits)
ratio = rle_compression_ratio(quant_data, encoded_data, module.word_length, module.l_bits)
module.encoding_ratio.update(ratio, batch_size)

def rle_model(model_wrapper, l_bits):
assert "quantization" in model_wrapper.sideband_info.keys(), "Only quantized models can be encoded"

weight_width = model_wrapper.sideband_info["quantization"]["weight_width"]
data_width = model_wrapper.sideband_info["quantization"]["data_width"]

Expand All @@ -60,7 +61,7 @@ def rle_model(model_wrapper, l_bits):
ratio = rle_compression_ratio(quant_weight, encoded_weight, weight_width, l_bits)
encode_info[name] = {"weight_compression_ratio": ratio}
elif isinstance(module, QuantAct):
assert name.endswith(".0") or name.endswith(".2")
assert name.endswith(".0") or name.endswith(".2")
module.l_bits = l_bits
module.encoding_ratio = AverageMeter('compression ratio', ':6.3f')
module.register_forward_hook(log_encoding)
Expand All @@ -78,5 +79,5 @@ def rle_model(model_wrapper, l_bits):
assert False, "unexpected module name"

model_wrapper.sideband_info["encoding"] = encode_info
return avg_compress_ratio(encode_info)

return avg_compress_ratio(encode_info), avg_compress_ratio_detailed(encode_info)
21 changes: 18 additions & 3 deletions encoding/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import torch
from collections import defaultdict

import numpy as np
import torch

from quantization.utils import linear_quantize

def convert_to_int(x, word_length, scaling_factor, zero_point, transpose):

def convert_to_int(x, word_length, scaling_factor, zero_point, transpose):
if torch.cuda.is_available():
x = x.cuda()
scaling_factor = scaling_factor.cuda()
Expand All @@ -24,4 +28,15 @@ def avg_compress_ratio(encode_info):
for v in encode_info.values():
compression_ratio += list(v.values())
compression_ratio = np.mean(compression_ratio)
return compression_ratio
return compression_ratio

def avg_compress_ratio_detailed(encode_info):
sum_dict = defaultdict(float)
count_dict = defaultdict(int)

for v in encode_info.values():
for key, value in v.items():
sum_dict[key] += value
count_dict[key] += 1

return {key: sum_dict[key] / count_dict[key] for key in sum_dict}
21 changes: 15 additions & 6 deletions encoding_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import os
import pathlib
import random
from statistics import mean

import torch

from encoding.huffman import huffman_model
from encoding.rle import rle_model
from models import initialize_wrapper
from quantization.utils import QuantMode, quantize_model


def main():
parser = argparse.ArgumentParser(description='Quantization Example')
parser.add_argument('--dataset_name', default="camvid", type=str,
Expand All @@ -30,7 +33,7 @@ def main():

args = parser.parse_args()
if args.output_path == None:
args.output_path = os.path.join(os.getcwd(),
args.output_path = os.path.join(os.getcwd(),
f"output/{args.dataset_name}/{args.model_name}")
pathlib.Path(args.output_path).mkdir(parents=True, exist_ok=True)
print(args)
Expand All @@ -50,14 +53,20 @@ def main():
'weight_width': 8, 'data_width': 8, 'mode': QuantMode.CHANNEL_BFP})

# encoding
#print("Encoding model in RLE...")
#ratio = rle_model(model_wrapper, 8)
#print("compression ratio: ", ratio)
#model_wrapper.generate_onnx_files(os.path.join(args.output_path, "rle"))
# print("Encoding model in RLE...")
# ratio, ratio_detailed = rle_model(model_wrapper, 8)
# print("compression ratio: ", ratio)
# for k, v in ratio_detailed.items():
# print(f"{k}: {v}")
# print("compression ratio (detailed): ", mean(ratio_detailed.values()))
# model_wrapper.generate_onnx_files(os.path.join(args.output_path, "rle"))

print("Encoding model in Huffman...")
ratio = huffman_model(model_wrapper)
ratio, ratio_detailed = huffman_model(model_wrapper)
print("compression ratio: ", ratio)
for k, v in ratio_detailed.items():
print(f"{k}: {v}")
print("compression ratio (detailed): ", mean(ratio_detailed.values()))
model_wrapper.generate_onnx_files(os.path.join(args.output_path, "huffman"))

if __name__ == '__main__':
Expand Down

0 comments on commit b9a2fe0

Please sign in to comment.