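"""Quantization example: run a model wrapper at float32 and under several
quantization modes (network-wide FP16/FP8, layer-wise BFP8, channel-wise BFP8),
exporting ONNX files for each configuration.

Example invocation (argument values are illustrative; see the argparse
defaults below):

    python quantization_example.py --dataset_name imagenet \
        --dataset_path ~/dataset/ILSVRC2012_img --model_name resnet18 --gpu 0
"""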
import argparse
import os
import pathlib
import random

import torch

from models import initialize_wrapper
from quantization.utils import QuantMode, quantize_model


def main():
    parser = argparse.ArgumentParser(description='Quantization Example')
    parser.add_argument('--dataset_name', default="imagenet", type=str,
                        help='dataset name')
    parser.add_argument('--dataset_path', metavar='DIR', default="~/dataset/ILSVRC2012_img",
                        help='path to dataset')
    parser.add_argument('--model_name', metavar='ARCH', default='resnet18',
                        help='model architecture')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers')
    parser.add_argument('-b', '--batch-size', default=64, type=int, metavar='N',
                        help='mini-batch size')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use.')
    parser.add_argument('--output_path', default=None, type=str,
                        help='output path')
    args = parser.parse_args()

    if args.output_path is None:
        args.output_path = os.path.join(os.getcwd(),
                                        f"output/{args.dataset_name}/{args.model_name}")
    pathlib.Path(args.output_path).mkdir(parents=True, exist_ok=True)
    print(args)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)

    random.seed(0)
    torch.manual_seed(0)

    model_wrapper = initialize_wrapper(args.dataset_name, args.model_name,
                                       os.path.expanduser(args.dataset_path),
                                       args.batch_size, args.workers)

    # TEST 1
    print("FLOAT32 Inference")
    model_wrapper.inference("test")
    model_wrapper.generate_onnx_files(
        os.path.join(args.output_path, "float32"))
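
    # The config dict handed to quantize_model selects the quantization settings:
    # 'weight_width' and 'data_width' presumably choose the bit widths for the
    # weights and the activation/feature data, and 'mode' picks the QuantMode
    # scheme applied to the wrapped model (semantics inferred from the key names).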

    # TEST 2
    print("NETWORK FP16 Inference")
    # reload the model every time a new quantization mode is tested
    model_wrapper.load_model()
    quantize_model(model_wrapper, {
        'weight_width': 16, 'data_width': 16, 'mode': QuantMode.NETWORK_FP})
    model_wrapper.inference("test")
    model_wrapper.generate_onnx_files(
        os.path.join(args.output_path, "network_fp16"))

    # TEST 3
    print("NETWORK FP8 Inference")
    model_wrapper.load_model()
    quantize_model(model_wrapper, {
        'weight_width': 8, 'data_width': 8, 'mode': QuantMode.NETWORK_FP})
    model_wrapper.inference("test")
    model_wrapper.generate_onnx_files(
        os.path.join(args.output_path, "network_fp8"))
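
    # LAYER_BFP and CHANNEL_BFP appear to be block-floating-point modes that
    # share a scaling factor per layer or per channel respectively (inferred
    # from the mode names); the bit widths stay at 8 in both cases below.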

    # TEST 4
    print("LAYER BFP8 Inference")
    model_wrapper.load_model()
    quantize_model(model_wrapper, {
        'weight_width': 8, 'data_width': 8, 'mode': QuantMode.LAYER_BFP})
    model_wrapper.inference("test")
    model_wrapper.generate_onnx_files(
        os.path.join(args.output_path, "layer_bfp8"))

    # TEST 5
    print("CHANNEL BFP8 Inference")
    # note: CHANNEL_BFP can be worse than LAYER_BFP if the calibration set is small!
    model_wrapper.load_model()
    quantize_model(model_wrapper, {
        'weight_width': 8, 'data_width': 8, 'mode': QuantMode.CHANNEL_BFP})
    model_wrapper.inference("test")
    model_wrapper.generate_onnx_files(
        os.path.join(args.output_path, "channel_bfp8"))


if __name__ == '__main__':
    main()